mistralrs_quant/lora/mod.rs

mod static_lora;

use std::{cell::RefCell, collections::HashSet};

use candle_core::{DType, Result, Tensor};
use regex::Regex;
use serde::{Deserialize, Serialize};
pub use static_lora::linear_no_bias_static_lora;

use crate::{Shard, ShardedVarBuilder};

// Per-engine-thread registry of LoRA adapters to merge into weights as they are
// loaded; see `merge_lora_weights`.
thread_local! {
    static ENGINE_APPLIED_LORAS: RefCell<Vec<LoraAdapter>> = const { RefCell::new(Vec::new()) };
}

/// Get the LoRA adapters for the current engine thread
pub fn get_applied_loras() -> Vec<LoraAdapter> {
    ENGINE_APPLIED_LORAS.with(|loras| loras.borrow().clone())
}

/// Push a LoRA adapter for the current engine thread
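///
/// A minimal sketch of the intended per-thread flow (the `adapter` value is assumed
/// to have been built from an adapter checkpoint elsewhere):
///
/// ```ignore
/// push_applied_lora(adapter);
/// // ... load the model; `merge_lora_weights` folds the adapter into matching layers ...
/// clear_applied_loras();
/// ```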
pub fn push_applied_lora(adapter: LoraAdapter) {
    ENGINE_APPLIED_LORAS.with(|loras| loras.borrow_mut().push(adapter));
}

/// Clear all LoRA adapters for the current engine thread
pub fn clear_applied_loras() {
    ENGINE_APPLIED_LORAS.with(|loras| loras.borrow_mut().clear());
}

/// Delimiter used when multiple LoRA adapters are specified in a single string.
pub const MULTI_LORA_DELIMITER: &str = ";";

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticLoraConfig {
    pub layer: String,
    pub lora_alpha: f64,
    pub r: usize,
}

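/// LoRA hyperparameters for one adapter. The `serde` renames follow PEFT's
/// `adapter_config.json` naming; a hedged example of the JSON shape this is
/// assumed to deserialize from:
///
/// ```json
/// {
///     "r": 8,
///     "lora_alpha": 16,
///     "target_modules": ["q_proj", "v_proj"]
/// }
/// ```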
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoraConfig {
    #[serde(rename = "r")]
    pub rank: usize,
    #[serde(rename = "lora_alpha")]
    pub alpha: f64,
    pub target_modules: HashSet<String>,
}

/// A loaded LoRA adapter: its configuration plus a var-builder over its weight tensors.
#[derive(Clone)]
pub struct LoraAdapter {
    pub config: LoraConfig,
    pub weights: ShardedVarBuilder,
}

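/// Fold every LoRA adapter registered on this engine thread into `weight`.
///
/// For each adapter whose `target_modules` pattern matches the var-builder prefix,
/// the standard LoRA update is applied:
///
/// ```text
/// W' = W + (alpha / r) * B @ A    // A: (r, in_dim), B: (out_dim, r)
/// ```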
pub(crate) fn merge_lora_weights(
    vb: &ShardedVarBuilder,
    mut weight: Tensor,
    in_dim: usize,
    out_dim: usize,
    shard: Shard,
) -> Result<Tensor> {
    let applied_loras = get_applied_loras();
    for LoraAdapter { config, weights } in &applied_loras {
        // Build an alternation regex from `target_modules` and skip layers whose
        // prefix does not match this adapter.
        let target_modules = config
            .target_modules
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("|");
        let regex = Regex::new(&target_modules).map_err(candle_core::Error::msg)?;
        if !regex.is_match(&vb.prefix()) {
            continue;
        }

        // PEFT checkpoints may nest tensors under `base_model.model`; prefer that
        // prefix when it contains the adapter weights for this layer.
        let weights = if weights
            .pp("base_model.model")
            .pp(vb.prefix())
            .contains_tensor("lora_A.weight")
        {
            weights.pp("base_model.model").pp(vb.prefix())
        } else {
            weights.pp(vb.prefix())
        };

        let a = weights.get_with_hints((config.rank, in_dim), "lora_A.weight", shard)?;
        let b = weights.get_with_hints((out_dim, config.rank), "lora_B.weight", shard)?;
        let scale = if config.rank > 0 {
            config.alpha / config.rank as f64
        } else {
            1.0
        };

        // On CPU, do the matmul in f32 (half-precision matmuls there are slow or unsupported).
        let ab = if a.device().is_cpu() {
            b.to_dtype(DType::F32)?.matmul(&a.to_dtype(DType::F32)?)?
        } else {
            b.matmul(&a)?
        };

        // delta = scale * (B @ A); cast back to the adapter dtype before adding to the base weight.
        let delta_weight = (ab * scale)?;
        weight = (weight + delta_weight.to_dtype(a.dtype())?)?;
    }

    Ok(weight)
}
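
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative sanity check only: constructing a `LoraAdapter` requires a real
    // `ShardedVarBuilder`, so this just exercises the thread-local registry lifecycle.
    #[test]
    fn registry_is_empty_after_clear() {
        clear_applied_loras();
        assert!(get_applied_loras().is_empty());
    }
}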