mistralrs_core/amoe/macros.rs

// Computes the LoRA weight delta for projection `$name`: reads `lora_A`
// (shape ($rank, $in_d)) and `lora_B` (shape ($out_d, $rank)) under the given
// VarBuilder, then returns scale * (B @ A) with scale = alpha / rank
// (falling back to 1.0 when rank is 0).
#[macro_export]
#[doc(hidden)]
macro_rules! get_delta_from_lora_ab {
    ($vb_mlp:expr, $rank:expr, $alpha:expr, ($in_d:expr, $out_d:expr), $name:expr) => {{
        let proj_a = $vb_mlp
            .pp($name)
            .pp("lora_A")
            .get(($rank, $in_d), "weight")?;
        let proj_b = $vb_mlp
            .pp($name)
            .pp("lora_B")
            .get(($out_d, $rank), "weight")?;
        // LoRA scaling factor: alpha / rank, guarding against division by zero.
        let scale = if $rank > 0 {
            $alpha / $rank as f64
        } else {
            1.0
        };
        // Delta = scale * (B @ A), with shape ($out_d, $in_d).
        (MatMul.matmul(&proj_b, &proj_a)? * scale)?
    }};
}

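// A minimal usage sketch, not part of the original file: loads the LoRA delta
// for a hypothetical `gate_proj` projection. The function name, the adapter
// key, and the `crate::layers::MatMul` import path are assumptions for
// illustration only.
#[allow(unused)]
fn example_lora_delta(
    vb_mlp: candle_nn::VarBuilder,
    rank: usize,
    alpha: f64,
    in_d: usize,
    out_d: usize,
) -> candle_core::Result<candle_core::Tensor> {
    // Assumed location of the matmul wrapper the macro expects in scope.
    use crate::layers::MatMul;
    // Expands to scale * (lora_B @ lora_A): the (out_d, in_d) delta LoRA adds
    // to the base projection weight.
    let delta = get_delta_from_lora_ab!(vb_mlp, rank, alpha, (in_d, out_d), "gate_proj");
    Ok(delta)
}
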
// Merges a weight delta into a `QMatMul`. Unquantized variants add the delta
// directly; a quantized tensor is dequantized, updated, and re-quantized at
// its original dtype.
#[macro_export]
#[doc(hidden)]
macro_rules! merge_delta {
    ($qmatmul:expr, $delta:expr) => {
        match &$qmatmul {
            QMatMul::Tensor(w) => QMatMul::Tensor((w + $delta)?),
            QMatMul::TensorF16(w) => QMatMul::TensorF16((w + $delta)?),
            QMatMul::QTensor(w) => {
                // Dequantize to apply the delta, then re-quantize to the original dtype.
                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
                QMatMul::QTensor(std::sync::Arc::new(
                    candle_core::quantized::QTensor::quantize(&(w + $delta)?, dtype)?,
                ))
            }
        }
    };
}
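
// A minimal usage sketch, not part of the original file: folds a delta into a
// `candle_core::quantized::QMatMul`. The function name is hypothetical, and the
// delta is assumed to match the shape of the underlying weight. Note the
// quantized arm is lossy, since it round-trips through
// dequantize -> add -> re-quantize.
#[allow(unused)]
fn example_merge(
    proj: candle_core::quantized::QMatMul,
    delta: &candle_core::Tensor,
) -> candle_core::Result<candle_core::quantized::QMatMul> {
    // Bring the enum into scope so the macro's `QMatMul::...` arms resolve.
    use candle_core::quantized::QMatMul;
    Ok(merge_delta!(proj, delta))
}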