mistralrs_core/amoe/macros.rs
/// Computes the dense LoRA delta `(alpha / rank) * lora_B @ lora_A` for the
/// projection `$name`, reading `lora_A` with shape `(rank, in_d)` and `lora_B`
/// with shape `(out_d, rank)` from the given `VarBuilder`.
#[macro_export]
#[doc(hidden)]
macro_rules! get_delta_from_lora_ab {
    ($vb_mlp:expr, $rank:expr, $alpha:expr, ($in_d:expr, $out_d:expr), $name:expr) => {{
        // LoRA A factor: (rank, in_features).
        let proj_a = $vb_mlp
            .pp($name)
            .pp("lora_A")
            .get(($rank, $in_d), "weight")?;
        // LoRA B factor: (out_features, rank).
        let proj_b = $vb_mlp
            .pp($name)
            .pp("lora_B")
            .get(($out_d, $rank), "weight")?;
        // Standard LoRA scaling alpha / rank; a rank of 0 falls back to 1.0
        // to avoid division by zero.
        let scale = if $rank > 0 {
            $alpha / $rank as f64
        } else {
            1.0
        };
        // Dense delta of shape (out_d, in_d).
        (MatMul.matmul(&proj_b, &proj_a)? * scale)?
    }};
}
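// Usage sketch (not part of the original file; `lora_delta_for_gate_proj`, its
// parameters, and the `"gate_proj"` name are illustrative). The macro expects
// `MatMul` to be in scope at the expansion site and a caller returning
// `candle_core::Result`, since it uses `?` internally.
//
// fn lora_delta_for_gate_proj(
//     vb_mlp: candle_nn::VarBuilder,
//     rank: usize,
//     alpha: f64,
//     hidden_size: usize,        // in_features of gate_proj
//     intermediate_size: usize,  // out_features of gate_proj
// ) -> candle_core::Result<candle_core::Tensor> {
//     // Reads `<prefix>/gate_proj/lora_A/weight` and `<prefix>/gate_proj/lora_B/weight`,
//     // returning the dense delta of shape (intermediate_size, hidden_size).
//     Ok(get_delta_from_lora_ab!(
//         vb_mlp,
//         rank,
//         alpha,
//         (hidden_size, intermediate_size),
//         "gate_proj"
//     ))
// }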

/// Adds `$delta` to the weight stored in a `QMatMul`, preserving its variant.
/// Plain (`Tensor`/`TensorF16`) weights are updated directly; quantized weights
/// are dequantized, updated, and requantized at their original dtype.
#[macro_export]
#[doc(hidden)]
macro_rules! merge_delta {
    ($qmatmul:expr, $delta:expr) => {
        match &$qmatmul {
            QMatMul::Tensor(w) => QMatMul::Tensor((w + $delta)?),
            QMatMul::TensorF16(w) => QMatMul::TensorF16((w + $delta)?),
            QMatMul::QTensor(w) => {
                // Dequantize on the tensor's device, apply the delta, then
                // requantize with the original quantization dtype.
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + $delta)?, dtype)?,
                ))
            }
        }
    };
}
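// Usage sketch (not part of the original file; the function and variable names are
// illustrative). Folds a LoRA delta into an existing projection weight, where
// `QMatMul` is `candle_core::quantized::QMatMul` and must be in scope at the
// expansion site. Merging a quantized weight round-trips through f32, so it
// introduces only requantization error.
//
// fn merge_gate_proj(
//     vb_mlp: candle_nn::VarBuilder,
//     gate_proj: QMatMul,
//     rank: usize,
//     alpha: f64,
//     hidden_size: usize,
//     intermediate_size: usize,
// ) -> candle_core::Result<QMatMul> {
//     let delta = get_delta_from_lora_ab!(
//         vb_mlp,
//         rank,
//         alpha,
//         (hidden_size, intermediate_size),
//         "gate_proj"
//     );
//     Ok(merge_delta!(gate_proj, delta))
// }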