#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::gguf::Content;
use crate::lora::{get_lora_cfg, LinearLayerLike, LoraConfig, Merge, Ordering, QLoraLinear};
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::ggml_file;
use candle_core::quantized::QMatMul;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{MatMul, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::layers::{CausalMasker, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::pipeline::{extract_logits, Cache, EitherCache};

use super::classifier::XLoraClassifier;
use super::{verify_sanity_adapters, NonGranularState, ScalingsMaker, XLoraConfig};
use crate::models::quantized_llama::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const MAX_SEQ_LEN: u32 = 4096;
const SUPPORTED_LAYERS: [&str; 8] = [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.up_proj",
    "mlp.down_proj",
    "mlp.gate_proj",
    "lm_head",
];

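/// SwiGLU feed-forward block. Each projection is a `QLoraLinear` (quantized
/// base matmul plus optional LoRA adapters), so X-LoRA scalings can be
/// threaded through every matmul.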
#[derive(Debug)]
struct Mlp {
    feed_forward_w1: QLoraLinear,
    feed_forward_w2: QLoraLinear,
    feed_forward_w3: QLoraLinear,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let w1 = self.feed_forward_w1.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let w3 = self.feed_forward_w3.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        self.feed_forward_w2.lora_forward(
            &(candle_nn::ops::silu(&w1)? * w3)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )
    }
}

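/// Either a dense MLP or a mixture-of-experts block. The MoE variant routes
/// each token through the `n_expert_used` experts with the highest gate
/// scores and combines their outputs using normalized routing weights.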
#[derive(Debug)]
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: QMatMul,
        experts: Vec<Mlp>,
    },
}

impl MlpOrMoe {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmatmul(&xs, feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

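                // Select the top `n_expert_used` experts per row on the host: record,
                // for each expert, the rows it serves (`top_x`) and the routing weight
                // renormalized over the selected experts.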
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

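                // Run each selected expert only on its routed rows, then scatter the
                // weighted outputs back into `ys` with `index_add`.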
                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    let current_hidden_states = expert_layer.forward(
                        &current_state,
                        scalings.clone(),
                        global_scaling_weight,
                        is_scaling_pass,
                    )?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => {
                mlp.forward(xs, scalings.clone(), global_scaling_weight, is_scaling_pass)
            }
        }
    }
}

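/// One transformer block: quantized attention projections, RMS norms, rotary
/// embedding, and a dense or MoE feed-forward, each with optional LoRA
/// adapters.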
struct LayerWeights {
    attention_wq: QLoraLinear,
    attention_wk: QLoraLinear,
    attention_wv: QLoraLinear,
    attention_wo: QLoraLinear,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
    #[allow(clippy::too_many_arguments)]
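    /// Self-attention for one block: apply the Q/K/V LoRA projections and
    /// rotary embeddings, update the KV cache, run SDPA, and project back
    /// through `attention_wo`.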
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self
            .attention_wq
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;
        let k = self
            .attention_wk
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;
        let v = self
            .attention_wv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attention_wo.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

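/// The complete quantized model with an optional X-LoRA classifier, optional
/// device mapping, and a full (non-paged) KV cache.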
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: QLoraLinear,
    pub device: Device,
    pub cache: EitherCache,
    xlora_classifier: Option<XLoraClassifier>,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

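// Construction from a GGML checkpoint: tensors are consumed from the file and
// each projection is wrapped in a `QLoraLinear` so adapters from `vb` can be
// attached (and merged when neither X-LoRA nor preloaded adapters are used).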
impl ModelConfig::FromAdapterGGML for ModelWeights {
    fn from_ggml(
        mut ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        let mut count = 0;
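        // Load each block, registering every projection under its HF-style name
        // (e.g. `model.layers.{i}.self_attn.q_proj`) so adapters from `vb` and
        // `ordering` can attach to it; `count` tallies the LoRA layers created.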
        for layer_idx in 0..ct.hparams.n_layer {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head: ct.hparams.n_head as usize / gqa,
                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                rotary: rotary.clone().into(),
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
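        // Without an X-LoRA classifier or preloaded adapters, the LoRA deltas can be
        // merged into the base weights so inference skips the adapter path entirely.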
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output,
            device: ct.device.clone(),
            cache: EitherCache::Full(Cache::new(ct.hparams.n_layer as usize, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len: MAX_SEQ_LEN as usize,
            mapper: None,
            dtype,
        })
    }
}

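// Construction from a GGUF checkpoint: hyperparameters come from the GGUF
// metadata (`PropsGGUF`), layers may be spread across devices by the mapper,
// and rotary embeddings are shared per device.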
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
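        // Models with tied embeddings omit `output.weight`; fall back to the token
        // embedding tensor in that case.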
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);
        let mut count = 0;

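        // Build one RotaryEmbedding per device location so all layers mapped to the
        // same device share the cos/sin tables.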
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
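            // A dense MLP when `n_expert <= 1`; otherwise a MoE block with a gating
            // matmul and one expert MLP per `n_expert`.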
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
                for i in 0..n_expert {
                    let feed_forward_w1 =
                        ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                    let feed_forward_w2 =
                        ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                    let feed_forward_w3 =
                        ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                    let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                    let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                    let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                    experts.push(Mlp {
                        feed_forward_w1: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w1)?,
                            &cfg_w1,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.gate_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w2: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w2)?,
                            &cfg_w2,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.down_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w3: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w3)?,
                            &cfg_w3,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.up_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                    })
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                rotary: rotary.clone(),
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output,
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
    #[allow(clippy::too_many_arguments)]
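    /// Run the embeddings, transformer stack, and final norm. `is_full_pass`
    /// selects the dedicated X-LoRA cache (resetting it when `no_kv_cache` is
    /// set), and `is_scaling_pass` is forwarded to every LoRA layer.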
    fn inner_forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask =
            CausalMasker.make_causal_mask_matrix(x, &*cache, self.dtype, self.layers[0].n_head)?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                &mask.as_ref().map(|m| m.to_device(x.device()).unwrap()),
                start_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let x = (attn + residual)?;

            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(
                &x,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        self.norm.forward(&layer_in)
    }

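    /// Public forward pass. With an X-LoRA classifier, first compute per-token
    /// scalings via `get_scalings`, then run the scaled pass (over the full
    /// prompt when `no_kv_cache` is set); otherwise run a plain LoRA pass.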
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

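// `ScalingsMaker` provides the hooks (dtype, cache, classifier, and a
// scaling-pass `forward`) used by the shared X-LoRA `get_scalings`
// implementation.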
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}