mistralrs_quant/lora/static_lora.rs

use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Result};
use candle_nn::Linear;
use regex::Regex;

use crate::{DummyLayer, QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear};

use super::StaticLoraConfig;
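/// Builds a bias-free linear layer with any matching static LoRA adapters merged into
/// the base weight at load time. For each adapter whose `layer` regex matches the
/// VarBuilder prefix, `scale * (lora_B @ lora_A)` with `scale = lora_alpha / r` is added
/// to `base_layer.weight`. If the base weight tensor is absent, a dummy layer is returned.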
pub fn linear_no_bias_static_lora(
    in_dim: usize,
    out_dim: usize,
    loras: HashMap<String, StaticLoraConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let layer = {
        if !vb.contains_tensor("base_layer.weight") {
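            // No base weight tensor here: return a dummy (no-op) layer.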
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let mut weight =
                vb.get_with_hints((out_dim, in_dim), "base_layer.weight", Default::default())?;

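            // Merge every adapter whose `layer` regex matches this layer's prefix.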
            for (name, lora_cfg) in loras {
                let regex = Regex::new(&lora_cfg.layer).map_err(candle_core::Error::msg)?;
                if !regex.is_match(&vb.prefix()) {
                    continue;
                }

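                // lora_A is (r, in_dim) and lora_B is (out_dim, r); scale = lora_alpha / r.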
                let a = vb.get((lora_cfg.r, in_dim), &format!("lora_A.{name}.weight"))?;
                let b = vb.get((out_dim, lora_cfg.r), &format!("lora_B.{name}.weight"))?;
                let scale = if lora_cfg.r > 0 {
                    lora_cfg.lora_alpha / lora_cfg.r as f64
                } else {
                    1.0
                };

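                // Compute B @ A, upcasting to f32 on CPU and using the native dtype elsewhere.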
                let ab = if a.device().is_cpu() {
                    b.to_dtype(DType::F32)?.matmul(&a.to_dtype(DType::F32)?)?
                } else {
                    b.matmul(&a)?
                };

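                // Scale the delta and fold it into the base weight in the adapter's dtype.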
                let delta_weight = (ab * scale)?;
                weight = (weight + delta_weight.to_dtype(a.dtype())?)?;
            }

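            // Wrap the merged weight as an unquantized linear layer with no bias.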
            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    Ok(layer)
}
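
// Illustrative usage (a minimal sketch, not part of this file): merging one static
// LoRA adapter into a projection weight at load time. The adapter name, config values,
// regex, and `vb` prefix below are assumptions for the example, and `StaticLoraConfig`
// may carry additional fields in practice.
//
//     let mut loras = HashMap::new();
//     loras.insert(
//         "adapter_0".to_string(),
//         StaticLoraConfig {
//             layer: r".*\.q_proj$".to_string(), // merged only where this matches vb.prefix()
//             r: 16,
//             lora_alpha: 32.0, // effective scale on B @ A is 32.0 / 16 = 2.0
//         },
//     );
//     // `vb` should already be scoped to the target layer, e.g. "model.layers.0.self_attn.q_proj".
//     let q_proj = linear_no_bias_static_lora(hidden_size, hidden_size, loras, vb)?;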