mistralrs_core/xlora_models/quantized_llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::gguf::Content;
use crate::lora::{get_lora_cfg, LinearLayerLike, LoraConfig, Merge, Ordering, QLoraLinear};
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::ggml_file;
use candle_core::quantized::QMatMul;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{MatMul, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::layers::{CausalMasker, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::pipeline::{extract_logits, Cache, EitherCache};

use super::classifier::XLoraClassifier;
use super::{verify_sanity_adapters, NonGranularState, ScalingsMaker, XLoraConfig};
use crate::models::quantized_llama::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const MAX_SEQ_LEN: u32 = 4096;
const SUPPORTED_LAYERS: [&str; 8] = [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.up_proj",
    "mlp.down_proj",
    "mlp.gate_proj",
    "lm_head",
];

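/// SwiGLU feed-forward block. Each projection is a quantized matmul wrapped in
/// `QLoraLinear`, so LoRA/X-LoRA deltas can be applied on every call:
/// `w2(silu(w1(x)) * w3(x))`.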
#[derive(Debug)]
struct Mlp {
    feed_forward_w1: QLoraLinear,
    feed_forward_w2: QLoraLinear,
    feed_forward_w3: QLoraLinear,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let w1 = self.feed_forward_w1.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let w3 = self.feed_forward_w3.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        self.feed_forward_w2.lora_forward(
            &(candle_nn::ops::silu(&w1)? * w3)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )
    }
}

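/// Either a dense MLP or a Mixtral-style mixture of experts. In the MoE case a
/// gating matmul scores every expert per token, the top `n_expert_used` experts
/// run on the tokens routed to them, and their outputs are combined with
/// renormalized routing weights.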
#[derive(Debug)]
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: QMatMul,
        experts: Vec<Mlp>,
    },
}

impl MlpOrMoe {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmatmul(&xs, feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // In order to extract topk, we extract the data from the tensor and manipulate it
                // directly. Maybe we will want to use some custom ops instead at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indexes to evaluate for each expert.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the correct hidden states and compute the expert hidden state for
                    // the current expert. We need to make sure to multiply the output hidden
                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(
                        &current_state,
                        scalings.clone(),
                        global_scaling_weight,
                        is_scaling_pass,
                    )?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => {
                mlp.forward(xs, scalings.clone(), global_scaling_weight, is_scaling_pass)
            }
        }
    }
}

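/// A single transformer block: LoRA-wrapped Q/K/V/O projections, RMSNorms for
/// attention and feed-forward, rotary embeddings, and the MLP/MoE sub-block.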
struct LayerWeights {
    attention_wq: QLoraLinear,
    attention_wk: QLoraLinear,
    attention_wv: QLoraLinear,
    attention_wo: QLoraLinear,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
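    /// Attention for one block: apply the (possibly adapted) Q/K/V projections,
    /// rotary embeddings, and the KV-cache update, then run SDPA and the output
    /// projection. `scalings`, `global_scaling_weight`, and `is_scaling_pass` are
    /// the X-LoRA knobs threaded through to each `QLoraLinear`.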
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self
            .attention_wq
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;
        let k = self
            .attention_wk
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;
        let v = self
            .attention_wv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attention_wo.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

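/// Quantized Llama (or Mixtral-style MoE) weights with LoRA/X-LoRA adapters
/// attached to the attention, MLP, and (optionally) `lm_head` projections.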
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: QLoraLinear,
    pub device: Device,
    pub cache: EitherCache,
    xlora_classifier: Option<XLoraClassifier>,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

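/// Loader for the legacy GGML format. Hyperparameters come from
/// `ggml_file::Content`; the context length is fixed to `MAX_SEQ_LEN` because
/// GGML files do not record it. If neither X-LoRA nor preloaded adapters are in
/// use, the LoRA weights are merged into the base weights right away.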
impl ModelConfig::FromAdapterGGML for ModelWeights {
    fn from_ggml(
        mut ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        let mut count = 0;
        for layer_idx in 0..ct.hparams.n_layer {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head: ct.hparams.n_head as usize / gqa,
                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                rotary: rotary.clone().into(),
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // Neither X-LoRA nor preloaded adapters: this is a plain LoRA model, so merge the adapter weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            // This is why `forward` can pass dummy values (..., None, 1.0, None) to the output layer.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output,
            device: ct.device.clone(),
            cache: EitherCache::Full(Cache::new(ct.hparams.n_layer as usize, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml.
            mapper: None,
            dtype,
        })
    }
}

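/// Loader for GGUF files. Hyperparameters (head counts, RoPE settings, context
/// length, expert counts) are read from the GGUF metadata via `PropsGGUF`, a
/// per-device RoPE cache is built from the device mapper, and adapters are
/// merged eagerly when neither X-LoRA nor preloaded adapters are in use.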
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);
        let mut count = 0;

        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
                for i in 0..n_expert {
                    let feed_forward_w1 =
                        ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                    let feed_forward_w2 =
                        ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                    let feed_forward_w3 =
                        ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                    let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                    let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                    let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                    experts.push(Mlp {
                        feed_forward_w1: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w1)?,
                            &cfg_w1,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.gate_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w2: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w2)?,
                            &cfg_w2,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.down_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w3: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w3)?,
                            &cfg_w3,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.up_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                    })
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                rotary: rotary.clone(),
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // Neither X-LoRA nor preloaded adapters: this is a plain LoRA model, so merge the adapter weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            // This is why `forward` can pass dummy values (..., None, 1.0, None) to the output layer.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output,
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
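    /// Shared forward pass over the transformer stack, ending at the final RMSNorm
    /// (the `lm_head` projection is applied by the callers). Full (X-LoRA) passes
    /// use the dedicated `xlora_lock` KV cache, clearing it first when
    /// `no_kv_cache` is set; otherwise the regular per-layer cache is used.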
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask =
            CausalMasker.make_causal_mask_matrix(x, &*cache, self.dtype, self.layers[0].n_head)?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                &mask.as_ref().map(|m| m.to_device(x.device()).unwrap()),
                start_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let x = (attn + residual)?;

            // MLP
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(
                &x,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        self.norm.forward(&layer_in)
    }

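    /// Main entry point. With an X-LoRA classifier present, `get_scalings` first
    /// runs a scaling pass to compute per-token scalings, then the model is re-run
    /// with those scalings; without one this is a plain (merged) LoRA forward.
    /// Dummy adapter arguments (`None, 1.0, None`) are passed to the output layer;
    /// with X-LoRA an adapted `lm_head` is rejected at load time.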
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

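/// `ScalingsMaker` lets the X-LoRA classifier drive scaling passes: its
/// `forward` simply re-runs `inner_forward` with the candidate scalings.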
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32 // for dummy scalings
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}