mistralrs_quant/lora/mod.rs

//! LoRA adapter configuration and merging of adapter weights into base weights.

mod static_lora;

use std::{
    collections::HashSet,
    sync::{Arc, LazyLock, Mutex},
};

use candle_core::{DType, Result, Tensor};
use regex::Regex;
use serde::{Deserialize, Serialize};
pub use static_lora::linear_no_bias_static_lora;

use crate::{Shard, ShardedVarBuilder};

/// Globally registered LoRA adapters. Model loading code pushes adapters here so
/// that `merge_lora_weights` can fold them into matching layer weights as the
/// model is built.
pub static APPLIED_LORAS: LazyLock<Arc<Mutex<Vec<LoraAdapter>>>> =
    LazyLock::new(|| Arc::new(Mutex::new(Vec::new())));

/// Delimiter used when multiple LoRA adapters are specified in a single string.
pub const MULTI_LORA_DELIMITER: &str = ";";

/// Configuration for a single statically applied LoRA layer.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticLoraConfig {
    pub layer: String,
    pub lora_alpha: f64,
    pub r: usize,
}

/// The subset of a LoRA adapter config needed for merging: rank, alpha, and the
/// modules the adapter targets.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoraConfig {
    #[serde(rename = "r")]
    pub rank: usize,
    #[serde(rename = "lora_alpha")]
    pub alpha: f64,
    pub target_modules: HashSet<String>,
}

/// A loaded LoRA adapter: its configuration plus a var builder over its weights.
pub struct LoraAdapter {
    pub config: LoraConfig,
    pub weights: ShardedVarBuilder,
}

/// Merge every adapter in [`APPLIED_LORAS`] whose `target_modules` match the
/// var builder's prefix into `weight`, i.e. compute
/// `W + sum(alpha / rank * B * A)` over the matching adapters.
pub(crate) fn merge_lora_weights(
    vb: &ShardedVarBuilder,
    mut weight: Tensor,
    in_dim: usize,
    out_dim: usize,
    shard: Shard,
) -> Result<Tensor> {
    for LoraAdapter { config, weights } in &*APPLIED_LORAS.lock().expect("No loras initialized.") {
        // Treat the target module names as regex alternatives and skip adapters
        // that do not target the layer identified by the var builder prefix.
        let target_modules = config
            .target_modules
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("|");
        let regex = Regex::new(&target_modules).map_err(candle_core::Error::msg)?;
        if !regex.is_match(&vb.prefix()) {
            continue;
        }
        let weights = weights.set_prefix(vb.prefix());

        // LoRA factors: A is (rank, in_dim) and B is (out_dim, rank).
        let a = weights.get_with_hints((config.rank, in_dim), "lora_A.weight", shard)?;
        let b = weights.get_with_hints((out_dim, config.rank), "lora_B.weight", shard)?;
        let scale = if config.rank > 0 {
            config.alpha / config.rank as f64
        } else {
            1.0
        };

        // On CPU, do the matmul in f32; the result is converted back to the
        // original dtype below.
        let ab = if a.device().is_cpu() {
            b.to_dtype(DType::F32)?.matmul(&a.to_dtype(DType::F32)?)?
        } else {
            b.matmul(&a)?
        };

        let delta_weight = (ab * scale)?;
        weight = (weight + delta_weight.to_dtype(a.dtype())?)?;
    }

    Ok(weight)
}
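
// --- Illustrative sketch (not part of the upstream module) ---
// A minimal test showing how `target_modules` acts as a set of regex
// alternatives over layer prefixes and how the merge scale relates to alpha and
// rank. The layer prefixes and the module name `lora_config_sketch` are
// hypothetical examples, not identifiers from the original code.
#[cfg(test)]
mod lora_config_sketch {
    use super::*;

    #[test]
    fn target_modules_select_layers_and_scale_is_alpha_over_rank() -> Result<()> {
        let config = LoraConfig {
            rank: 8,
            alpha: 16.0,
            target_modules: HashSet::from(["q_proj".to_string(), "v_proj".to_string()]),
        };

        // Mirror merge_lora_weights: join the target modules into regex
        // alternatives and match them against a layer's var-builder prefix.
        let pattern = config
            .target_modules
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("|");
        let regex = Regex::new(&pattern).map_err(candle_core::Error::msg)?;
        assert!(regex.is_match("model.layers.0.self_attn.q_proj"));
        assert!(!regex.is_match("model.layers.0.mlp.gate_proj"));

        // The delta added to the base weight is (alpha / rank) * B * A.
        assert_eq!(config.alpha / config.rank as f64, 2.0);
        Ok(())
    }
}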