mistralrs_core/amoe/macros.rs

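// Build the dense weight delta contributed by a LoRA adapter pair for the
// projection `$name`: load `lora_A` (shape (rank, in_d)) and `lora_B`
// (shape (out_d, rank)) from `$vb_mlp`, then return `B @ A` scaled by
// `alpha / rank`. Expands to an expression that uses `?`, so it must be
// invoked where candle errors can propagate.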
#[macro_export]
#[doc(hidden)]
macro_rules! get_delta_from_lora_ab {
    ($vb_mlp:expr, $rank:expr, $alpha:expr, ($in_d:expr, $out_d:expr), $name:expr) => {{
        // LoRA A matrix, shape (rank, in_d).
        let proj_a = $vb_mlp
            .pp($name)
            .pp("lora_A")
            .get(($rank, $in_d), "weight")?;
        // LoRA B matrix, shape (out_d, rank).
        let proj_b = $vb_mlp
            .pp($name)
            .pp("lora_B")
            .get(($out_d, $rank), "weight")?;
        // Standard LoRA scaling factor; fall back to 1.0 when the rank is 0
        // to avoid a division by zero.
        let scale = if $rank > 0 {
            $alpha / $rank as f64
        } else {
            1.0
        };
        // Dense delta: scale * (B @ A), shape (out_d, in_d).
        (proj_b.matmul(&proj_a)? * scale)?
    }};
}
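
// A minimal usage sketch (illustrative names, not from this file): `vb_mlp`
// would be a candle_nn VarBuilder positioned at an MLP's weight prefix, and
// the resulting delta has shape (intermediate_size, hidden_size).
//
//     let delta = get_delta_from_lora_ab!(
//         vb_mlp,
//         rank,
//         alpha,
//         (hidden_size, intermediate_size),
//         "gate_proj"
//     );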

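// Add a dense `$delta` to the weights behind a `QMatMul`, preserving the
// variant: unquantized tensors are summed directly, while quantized tensors
// are dequantized, summed, and requantized at their original dtype. Assumes
// `QMatMul` (candle_core::quantized::QMatMul) is in scope at the call site;
// like the macro above, it expands to an expression that uses `?`.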
#[macro_export]
#[doc(hidden)]
macro_rules! merge_delta {
    ($qmatmul:expr, $delta:expr) => {
        match &$qmatmul {
            // Unquantized weights: add the delta directly.
            QMatMul::Tensor(w) => QMatMul::Tensor((w + $delta)?),
            QMatMul::TensorF16(w) => QMatMul::TensorF16((w + $delta)?),
            // Quantized weights: dequantize, add the delta, then requantize
            // at the original quantized dtype (a lossy round trip).
            QMatMul::QTensor(w) => {
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + $delta)?, dtype)?,
                ))
            }
        }
    };
}
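
// A hedged usage sketch (illustrative field names, not from this file):
// folding a delta back into an existing projection. For the quantized arm
// the merge is accurate only up to the original quantization error.
//
//     self.gate_proj = merge_delta!(self.gate_proj, delta);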