mistralrs_quant/lora/mod.rs

mod static_lora;

use std::{
    collections::HashSet,
    sync::{Arc, LazyLock, Mutex},
};

use candle_core::{DType, Result, Tensor};
use regex::Regex;
use serde::{Deserialize, Serialize};
pub use static_lora::linear_no_bias_static_lora;

use crate::{Shard, ShardedVarBuilder};

/// Process-wide registry of LoRA adapters to be merged into matching layers at load time.
pub static APPLIED_LORAS: LazyLock<Arc<Mutex<Vec<LoraAdapter>>>> =
    LazyLock::new(|| Arc::new(Mutex::new(Vec::new())));

/// Delimiter used when specifying multiple LoRA adapters in a single string.
pub const MULTI_LORA_DELIMITER: &str = ";";

/// Per-layer configuration for a statically applied LoRA (see `static_lora`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticLoraConfig {
    pub layer: String,
    pub lora_alpha: f64,
    pub r: usize,
}

/// LoRA adapter configuration: rank, scaling alpha, and the target module patterns.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoraConfig {
    #[serde(rename = "r")]
    pub rank: usize,
    #[serde(rename = "lora_alpha")]
    pub alpha: f64,
    pub target_modules: HashSet<String>,
}

/// A loaded LoRA adapter: its parsed config plus a var builder over its weights.
pub struct LoraAdapter {
    pub config: LoraConfig,
    pub weights: ShardedVarBuilder,
}

pub(crate) fn merge_lora_weights(
    vb: &ShardedVarBuilder,
    mut weight: Tensor,
    in_dim: usize,
    out_dim: usize,
    shard: Shard,
) -> Result<Tensor> {
    for LoraAdapter { config, weights } in &*APPLIED_LORAS.lock().expect("No loras initialized.") {
        // Build an alternation regex from the adapter's target modules and skip
        // this adapter if it does not target the current layer prefix.
        let target_modules = config
            .target_modules
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("|");
        let regex = Regex::new(&target_modules).map_err(candle_core::Error::msg)?;
        if !regex.is_match(&vb.prefix()) {
            continue;
        }
        // Load this adapter's A (rank x in_dim) and B (out_dim x rank) matrices for the layer.
        let weights = weights.set_prefix(vb.prefix());

        let a = weights.get_with_hints((config.rank, in_dim), "lora_A.weight", shard)?;
        let b = weights.get_with_hints((out_dim, config.rank), "lora_B.weight", shard)?;
        // Standard LoRA scaling factor: alpha / r.
        let scale = if config.rank > 0 {
            config.alpha / config.rank as f64
        } else {
            1.0
        };

        // Compute B·A; on CPU, upcast to f32 since half-precision matmul there can be slow or unsupported.
        let ab = if a.device().is_cpu() {
            b.to_dtype(DType::F32)?.matmul(&a.to_dtype(DType::F32)?)?
        } else {
            b.matmul(&a)?
        };

        // Fold the scaled delta into the base weight: W <- W + (alpha / r) * B·A.
        let delta_weight = (ab * scale)?;
        weight = (weight + delta_weight.to_dtype(a.dtype())?)?;
    }

    Ok(weight)
}
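
For orientation, here is a minimal sketch of how a layer constructor inside this crate might call merge_lora_weights while loading a linear weight. The helper name linear_with_merged_lora and its signature are illustrative assumptions, not the crate's actual API; only get_with_hints and merge_lora_weights are taken from the source above.

// Hypothetical caller (illustrative only): load a base linear weight of shape
// (out_dim, in_dim) and fold any matching LoRA adapters into it before use.
pub(crate) fn linear_with_merged_lora(
    in_dim: usize,
    out_dim: usize,
    shard: Shard,
    vb: &ShardedVarBuilder,
) -> Result<Tensor> {
    // Base weight, stored as (out_dim, in_dim) so it matches the (B · A) delta shape.
    let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
    // Adds (alpha / r) * B·A for every adapter in APPLIED_LORAS whose
    // target_modules regex matches vb.prefix().
    merge_lora_weights(vb, weight, in_dim, out_dim, shard)
}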