mistralrs_quant/lora/mod.rs

mod static_lora;

use std::{cell::RefCell, collections::HashSet};

use candle_core::{DType, Result, Tensor};
use regex::Regex;
use serde::{Deserialize, Serialize};
pub use static_lora::linear_no_bias_static_lora;

use crate::{Shard, ShardedVarBuilder};

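/// Thread-local list of LoRA adapters currently applied; `merge_lora_weights`
/// reads this list when merging adapters into base weights.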
thread_local! {
    static ENGINE_APPLIED_LORAS: RefCell<Vec<LoraAdapter>> = const { RefCell::new(Vec::new()) };
}

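/// Returns a clone of the LoRA adapters currently applied on this thread.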
pub fn get_applied_loras() -> Vec<LoraAdapter> {
    ENGINE_APPLIED_LORAS.with(|loras| loras.borrow().clone())
}

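/// Adds a LoRA adapter to the thread-local list of applied adapters.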
pub fn push_applied_lora(adapter: LoraAdapter) {
    ENGINE_APPLIED_LORAS.with(|loras| loras.borrow_mut().push(adapter));
}

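/// Clears the thread-local list of applied LoRA adapters.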
pub fn clear_applied_loras() {
    ENGINE_APPLIED_LORAS.with(|loras| loras.borrow_mut().clear());
}

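/// Delimiter used to separate multiple LoRA adapter names in a single specification string.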
pub const MULTI_LORA_DELIMITER: &str = ";";

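/// Configuration for a static LoRA adapter: the target layer name, the
/// `lora_alpha` scaling factor, and the rank `r`.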
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticLoraConfig {
    pub layer: String,
    pub lora_alpha: f64,
    pub r: usize,
}

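/// LoRA adapter configuration: rank (`r`), scaling factor (`lora_alpha`), and
/// the set of regex patterns selecting the target modules.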
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoraConfig {
    #[serde(rename = "r")]
    pub rank: usize,
    #[serde(rename = "lora_alpha")]
    pub alpha: f64,
    pub target_modules: HashSet<String>,
}

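/// A loaded LoRA adapter: its configuration together with a var builder over its weight tensors.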
#[derive(Clone)]
pub struct LoraAdapter {
    pub config: LoraConfig,
    pub weights: ShardedVarBuilder,
}

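/// Merges every applied LoRA adapter whose `target_modules` patterns match the
/// var builder's prefix into `weight`, adding the scaled low-rank update
/// `(alpha / r) * B @ A` built from the adapter's `lora_A`/`lora_B` tensors.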
pub(crate) fn merge_lora_weights(
    vb: &ShardedVarBuilder,
    mut weight: Tensor,
    in_dim: usize,
    out_dim: usize,
    shard: Shard,
) -> Result<Tensor> {
    let applied_loras = get_applied_loras();
    for LoraAdapter { config, weights } in &applied_loras {
        // Join the adapter's target-module patterns into one regex and skip this
        // adapter if it does not match the current var builder prefix.
        let target_modules = config
            .target_modules
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("|");
        let regex = Regex::new(&target_modules).map_err(candle_core::Error::msg)?;
        if !regex.is_match(&vb.prefix()) {
            continue;
        }

        // Some adapter checkpoints nest their tensors under a `base_model.model`
        // prefix; use that path when it contains the LoRA tensors.
        let weights = if weights
            .pp("base_model.model")
            .pp(vb.prefix())
            .contains_tensor("lora_A.weight")
        {
            weights.pp("base_model.model").pp(vb.prefix())
        } else {
            weights.pp(vb.prefix())
        };

        let a = weights.get_with_hints((config.rank, in_dim), "lora_A.weight", shard)?;
        let b = weights.get_with_hints((out_dim, config.rank), "lora_B.weight", shard)?;
        // LoRA scaling factor alpha / r (1.0 if the rank is zero).
        let scale = if config.rank > 0 {
            config.alpha / config.rank as f64
        } else {
            1.0
        };

        // B @ A, computed in f32 on CPU and in the native dtype elsewhere.
        let ab = if a.device().is_cpu() {
            b.to_dtype(DType::F32)?.matmul(&a.to_dtype(DType::F32)?)?
        } else {
            b.matmul(&a)?
        };

        // W <- W + scale * (B @ A), cast back to the adapter dtype before the add.
        let delta_weight = (ab * scale)?;
        weight = (weight + delta_weight.to_dtype(a.dtype())?)?;
    }

    Ok(weight)
}