mistralrs_quant/lora/static_lora.rs

use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Result};
use candle_nn::Linear;
use regex::Regex;

use crate::{DummyLayer, QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear};

use super::StaticLoraConfig;

/// Static LoRA in the style of Phi-4 multimodal. A LoRA's delta is merged into
/// the base weight only when that LoRA's layer regex matches this layer's
/// prefix (`vb.prefix()`).
///
/// Structure:
/// - prefix.base_layer.weight
/// - prefix.lora_A.<lora name>.weight
/// - prefix.lora_B.<lora name>.weight
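///
/// For example, with a (hypothetical) prefix `model.layers.0.mlp.gate_up_proj`
/// and a LoRA adapter named `vision`, the expected tensors would be:
/// - model.layers.0.mlp.gate_up_proj.base_layer.weight
/// - model.layers.0.mlp.gate_up_proj.lora_A.vision.weight
/// - model.layers.0.mlp.gate_up_proj.lora_B.vision.weight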
pub fn linear_no_bias_static_lora(
    in_dim: usize,
    out_dim: usize,
    loras: HashMap<String, StaticLoraConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let layer = {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("base_layer.weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let mut weight =
                vb.get_with_hints((out_dim, in_dim), "base_layer.weight", Default::default())?;

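            // Fold each LoRA whose layer regex matches this layer's prefix
            // directly into the base weight; non-matching LoRAs are skipped.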
            for (name, lora_cfg) in loras {
                let regex = Regex::new(&lora_cfg.layer).map_err(candle_core::Error::msg)?;
                if !regex.is_match(&vb.prefix()) {
                    continue;
                }

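                // lora_A has shape (r, in_dim); lora_B has shape (out_dim, r).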
                let a = vb.get((lora_cfg.r, in_dim), &format!("lora_A.{name}.weight"))?;
                let b = vb.get((out_dim, lora_cfg.r), &format!("lora_B.{name}.weight"))?;
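                // Standard LoRA scaling: alpha / r, falling back to 1.0 if r == 0.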
                let scale = if lora_cfg.r > 0 {
                    lora_cfg.lora_alpha / lora_cfg.r as f64
                } else {
                    1.0
                };

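                // Low-rank product B @ A. On CPU the matmul is done in F32;
                // the result is cast back to the adapter dtype below.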
                let ab = if a.device().is_cpu() {
                    b.to_dtype(DType::F32)?.matmul(&a.to_dtype(DType::F32)?)?
                } else {
                    b.matmul(&a)?
                };

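                // Merge the scaled delta into the base weight: W += scale * (B @ A).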
                let delta_weight = (ab * scale)?;
                weight = (weight + delta_weight.to_dtype(a.dtype())?)?;
            }

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    Ok(layer)
}
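
#[cfg(test)]
mod tests {
    // Illustrative sketch of the merge math performed above, on plain candle
    // tensors. The shapes and values here are made up, and this does not
    // exercise the loader or `ShardedVarBuilder` itself; it only demonstrates
    //     W' = W + (alpha / r) * (B @ A)
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn lora_delta_merges_into_base_weight() -> Result<()> {
        let dev = Device::Cpu;
        let (in_dim, out_dim, r) = (8, 4, 2);
        let alpha = 16.0;
        let scale = alpha / r as f64;

        let weight = Tensor::zeros((out_dim, in_dim), DType::F32, &dev)?;
        let a = Tensor::ones((r, in_dim), DType::F32, &dev)?; // lora_A: (r, in)
        let b = Tensor::ones((out_dim, r), DType::F32, &dev)?; // lora_B: (out, r)

        let merged = (weight + (b.matmul(&a)? * scale)?)?;
        assert_eq!(merged.dims(), &[out_dim, in_dim]);

        // With all-ones A and B, every entry of B @ A equals r, so each merged
        // entry is r * scale = alpha.
        let first = merged.to_vec2::<f32>()?[0][0];
        assert_eq!(first, alpha as f32);
        Ok(())
    }
}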