1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::attention::SdpaParams;
7use crate::gguf::Content;
8use crate::lora::{
9 get_lora_cfg, AdapterSwapper, LinearLayerLike, LoraConfig, Merge, Ordering, QLoraLinear,
10};
11use crate::pipeline::text_models_inputs_processor::FlashParams;
12use crate::utils::progress::NiceProgressBar;
13use candle_core::quantized::ggml_file;
14use candle_core::quantized::QMatMul;
15use candle_core::{DType, Device, Result, Tensor};
16use candle_nn::{Embedding, Module};
17use indicatif::MultiProgress;
18use mistralrs_quant::{MatMul, ShardedVarBuilder};
19use tqdm::Iter;
20use tracing::info;
21
22use crate::device_map::DeviceMapper;
23use crate::layers::{CausalMasker, QRmsNorm, RotaryEmbedding, Sdpa};
24use crate::pipeline::{extract_logits, Cache, EitherCache};
25
26use super::classifier::XLoraClassifier;
27use super::{verify_sanity_adapters, NonGranularState, ScalingsMaker, XLoraConfig};
28use crate::models::quantized_llama::PropsGGUF;
29use crate::utils::gguf_metadata::ContentMetadata;
30use crate::utils::model_config as ModelConfig;
31
/// Fallback context length for the legacy GGML path, which (unlike GGUF)
/// reads no sequence-length metadata from the file.
const MAX_SEQ_LEN: u32 = 4096;
/// Module names that LoRA adapters may target; checked against the adapter
/// ordering via `verify_sanity_adapters` before loading GGUF weights.
const SUPPORTED_LAYERS: [&str; 8] = [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.up_proj",
    "mlp.down_proj",
    "mlp.gate_proj",
    "lm_head",
];
43
/// Gated feed-forward block built from quantized, LoRA-aware linear layers.
#[derive(Debug)]
struct Mlp {
    feed_forward_w1: QLoraLinear, // gate projection (loaded as `mlp.gate_proj`)
    feed_forward_w2: QLoraLinear, // down projection (loaded as `mlp.down_proj`)
    feed_forward_w3: QLoraLinear, // up projection (loaded as `mlp.up_proj`)
}
50
51impl Mlp {
52 fn forward(
53 &self,
54 xs: &Tensor,
55 scalings: Option<Tensor>,
56 global_scaling_weight: f64,
57 is_scaling_pass: Option<f64>,
58 ) -> Result<Tensor> {
59 let w1 = self.feed_forward_w1.lora_forward(
60 xs,
61 scalings.clone(),
62 global_scaling_weight,
63 is_scaling_pass,
64 )?;
65 let w3 = self.feed_forward_w3.lora_forward(
66 xs,
67 scalings.clone(),
68 global_scaling_weight,
69 is_scaling_pass,
70 )?;
71 self.feed_forward_w2.lora_forward(
72 &(candle_nn::ops::silu(&w1)? * w3)?,
73 scalings.clone(),
74 global_scaling_weight,
75 is_scaling_pass,
76 )
77 }
78}
79
/// Feed-forward stage of a decoder layer: either one dense MLP or a
/// mixture-of-experts layer with a learned router.
#[derive(Debug)]
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        /// Number of experts activated per token.
        n_expert_used: usize,
        /// Router ("gate") matmul producing one logit per expert.
        feed_forward_gate_inp: QMatMul,
        experts: Vec<Mlp>,
    },
}
89
90impl MlpOrMoe {
91 fn forward(
92 &self,
93 xs: &Tensor,
94 scalings: Option<Tensor>,
95 global_scaling_weight: f64,
96 is_scaling_pass: Option<f64>,
97 ) -> Result<Tensor> {
98 match self {
99 Self::MoE {
100 feed_forward_gate_inp,
101 experts,
102 n_expert_used,
103 } => {
104 let (b_size, seq_len, hidden_dim) = xs.dims3()?;
105 let xs = xs.reshape(((), hidden_dim))?;
106 let router_logits = MatMul.qmatmul(&xs, feed_forward_gate_inp)?;
107 let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
108
109 let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
112
113 let mut top_x = vec![vec![]; experts.len()];
116 let mut selected_rws = vec![vec![]; experts.len()];
117 for (row_idx, rw) in routing_weights.iter().enumerate() {
118 let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
119 dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
120 let mut sum_routing_weights = 0f32;
121 for &expert_idx in dst.iter().take(*n_expert_used) {
122 let expert_idx = expert_idx as usize;
123 let routing_weight = rw[expert_idx];
124 sum_routing_weights += routing_weight;
125 top_x[expert_idx].push(row_idx as u32);
126 }
127 for &expert_idx in dst.iter().take(*n_expert_used) {
128 let expert_idx = expert_idx as usize;
129 let routing_weight = rw[expert_idx];
130 selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
131 }
132 }
133
134 let mut ys = xs.zeros_like()?;
138 for (expert_idx, expert_layer) in experts.iter().enumerate() {
139 let top_x = &top_x[expert_idx];
140 if top_x.is_empty() {
141 continue;
142 }
143 let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
144 let selected_rws =
145 Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
146 .reshape(((), 1))?;
147 let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
151 let current_hidden_states = expert_layer.forward(
153 ¤t_state,
154 scalings.clone(),
155 global_scaling_weight,
156 is_scaling_pass,
157 )?;
158 let current_hidden_states =
159 current_hidden_states.broadcast_mul(&selected_rws)?;
160 ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
161 }
162
163 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
164 Ok(ys)
165 }
166 Self::Mlp(mlp) => {
167 mlp.forward(xs, scalings.clone(), global_scaling_weight, is_scaling_pass)
168 }
169 }
170 }
171}
172
/// Weights and attention configuration for one decoder layer.
struct LayerWeights {
    attention_wq: QLoraLinear,
    attention_wk: QLoraLinear,
    attention_wv: QLoraLinear,
    attention_wo: QLoraLinear,
    attention_norm: QRmsNorm, // RMSNorm applied before attention
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm, // RMSNorm applied before the feed-forward stage
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>, // shared per-device RoPE tables
    sdpa_params: SdpaParams,
    dtype: DType, // dtype Q/K/V are cast to for attention
}
188
impl LayerWeights {
    /// Self-attention for this layer.
    ///
    /// Projects `x` to Q/K/V through the LoRA-aware quantized linears, applies
    /// rotary embeddings, updates the per-layer KV cache, runs SDPA, and
    /// applies the output projection. `mask` is the (optional) causal mask;
    /// `scalings`/`global_scaling_weight`/`is_scaling_pass` are threaded
    /// through for X-LoRA.
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        // LoRA forward runs in the input dtype; cast to the attention dtype after.
        let q = self
            .attention_wq
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;
        let k = self
            .attention_wk
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;
        let v = self
            .attention_wv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        // Reshape to (batch, heads, seq, head_dim). For seq_len == 1 the
        // transpose is a no-op, so reshape directly into the target layout.
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        // Append the new K/V to the cache and attend over the full history.
        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        // Back to (batch, seq, n_embd), then output projection in x's dtype.
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attention_wo.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}
257
/// Quantized LLaMA-style model with (X-)LoRA adapters attached.
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,      // final RMSNorm before the LM head
    output: QLoraLinear, // LM head
    pub device: Device,
    pub cache: EitherCache,
    // Present only for X-LoRA models; drives the two-pass scaling scheme.
    xlora_classifier: Option<XLoraClassifier>,
    pub max_seq_len: usize,
    // Per-layer device mapper; set by the GGUF loader, None on the GGML path.
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}
270
impl ModelConfig::FromAdapterGGML for ModelWeights {
    /// Loads the model from legacy GGML content and attaches (X-)LoRA
    /// adapters to every supported linear layer.
    ///
    /// `gqa` is the grouped-query-attention ratio (`n_head / n_kv_head`).
    /// When neither an X-LoRA config nor preloaded adapters are supplied,
    /// the LoRA deltas are merged into the base weights up front.
    fn from_ggml(
        mut ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        // This path reads no RoPE metadata, so use hard-coded LLaMA defaults
        // (theta 10000, MAX_SEQ_LEN) with a partial rotary dim of `n_rot`.
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        // Running count of adapter targets created; fed to the X-LoRA classifier.
        let mut count = 0;
        for layer_idx in 0..ct.hparams.n_layer {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            // GGML models are always dense (no MoE tensors in this format's
            // naming here): w1 = gate, w2 = down, w3 = up.
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head: ct.hparams.n_head as usize / gqa,
                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                rotary: rotary.clone().into(),
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        // With no X-LoRA and no preloaded adapters the adapter set is static,
        // so fold the LoRA deltas into the base weights for faster inference.
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        // X-LoRA cannot carry an adapter on the LM head.
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output,
            device: ct.device.clone(),
            cache: EitherCache::Full(Cache::new(ct.hparams.n_layer as usize, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len: MAX_SEQ_LEN as usize, mapper: None,
            dtype,
        })
    }
}
468
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    /// Loads the model from GGUF content and attaches (X-)LoRA adapters.
    ///
    /// Hyperparameters (head counts, RoPE settings, max sequence length, …)
    /// come from the GGUF `llama.*` metadata. Layers may be mapped across
    /// devices via `mapper`; a RoPE table is built per device location. As in
    /// the GGML path, LoRA weights are merged into the base weights when no
    /// X-LoRA config and no preloaded adapters are present.
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        // Reject adapters targeting modules outside SUPPORTED_LAYERS.
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        // Attention assumes square K/V head dims.
        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        // Tied embeddings: fall back to token_embd when output.weight is absent.
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);
        // Running count of adapter targets created; fed to the X-LoRA classifier.
        let mut count = 0;

        // Build one RoPE table per device location so each mapped layer uses
        // tables resident on its own device.
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            // Dense MLP when n_expert <= 1, otherwise MoE with per-expert tensors.
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
                for i in 0..n_expert {
                    let feed_forward_w1 =
                        ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                    let feed_forward_w2 =
                        ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                    let feed_forward_w3 =
                        ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                    let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                    let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                    let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                    experts.push(Mlp {
                        feed_forward_w1: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w1)?,
                            &cfg_w1,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.gate_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w2: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w2)?,
                            &cfg_w2,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.down_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w3: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w3)?,
                            &cfg_w3,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.up_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                    })
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                rotary: rotary.clone(),
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        // With no X-LoRA and no preloaded adapters the adapter set is static,
        // so fold the LoRA deltas into the base weights for faster inference.
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        // X-LoRA cannot carry an adapter on the LM head.
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output,
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}
770
impl ModelWeights {
    /// Activates the named adapters on every LoRA layer and returns the total
    /// number of activations performed. Not supported for X-LoRA models,
    /// whose adapter set must stay fixed.
    pub fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += layer.attention_wk.activate(&adapter_names)?;
            sum += layer.attention_wo.activate(&adapter_names)?;
            sum += layer.attention_wq.activate(&adapter_names)?;
            sum += layer.attention_wv.activate(&adapter_names)?;
            match &mut layer.mlp_or_moe {
                MlpOrMoe::Mlp(ref mut m) => {
                    sum += m.feed_forward_w1.activate(&adapter_names)?;
                    sum += m.feed_forward_w2.activate(&adapter_names)?;
                    sum += m.feed_forward_w3.activate(&adapter_names)?;
                }
                MlpOrMoe::MoE {
                    n_expert_used: _,
                    feed_forward_gate_inp: _,
                    experts,
                } => {
                    for expert in experts {
                        sum += expert.feed_forward_w1.activate(&adapter_names)?;
                        sum += expert.feed_forward_w2.activate(&adapter_names)?;
                        sum += expert.feed_forward_w3.activate(&adapter_names)?;
                    }
                }
            }
        }
        Ok(sum)
    }

    /// Runs the transformer stack (embeddings through final norm) and returns
    /// hidden states; the LM head is applied by the caller.
    ///
    /// `is_full_pass` selects the X-LoRA cache instead of the regular KV
    /// cache; combined with `no_kv_cache` it resets that cache first.
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                // Reset the X-LoRA cache: replace every entry with None.
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask =
            CausalMasker.make_causal_mask_matrix(x, &*cache, self.dtype, self.layers[0].n_head)?;
        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to this layer's device when device-mapped.
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                &mask.as_ref().map(|m| m.to_device(x.device()).unwrap()),
                start_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let x = (attn + residual)?;

            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(
                &x,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            let x = (x + residual)?;
            layer_in = x;
        }
        // Bring the result back to the primary device for the final norm.
        let layer_in = layer_in.to_device(&self.device)?;
        self.norm.forward(&layer_in)
    }

    /// Full forward pass producing logits.
    ///
    /// For X-LoRA models this runs the two-pass scheme: first compute the
    /// per-adapter scalings, then run the real pass with them (over the full
    /// input when `no_kv_cache`, otherwise the incremental input). Plain LoRA
    /// models take a single ordinary pass.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}
963
964impl ScalingsMaker for ModelWeights {
965 fn dtype(&self) -> DType {
966 DType::F32 }
968 fn get_cache(&self) -> &EitherCache {
969 &self.cache
970 }
971 fn get_classifier(&self) -> &XLoraClassifier {
972 self.xlora_classifier.as_ref().unwrap()
973 }
974 fn forward(
975 &self,
976 input_ids: &Tensor,
977 seqlen_offsets: &[usize],
978 scalings: Tensor,
979 is_full_pass: bool,
980 no_kv_cache: bool,
981 is_scaling_pass: Option<f64>,
982 _context_lens: &[usize],
983 flash_params: &FlashParams,
984 ) -> Result<Tensor> {
985 self.inner_forward(
986 input_ids,
987 seqlen_offsets,
988 Some(scalings),
989 is_full_pass,
990 no_kv_cache,
991 is_scaling_pass,
992 flash_params,
993 )
994 }
995}