mistralrs_core/lora/loralinear.rs

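//! LoRA-wrapped linear layer: wraps a frozen base `Linear` (held as a
//! `QuantMethod`) and adds low-rank A/B adapter pairs on top of it, with an
//! optional stacked fast path when all adapters share the same shape.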
use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};

use candle_core::{DType, Module, Result, Tensor};
use candle_nn::Linear;
use either::Either;
use mistralrs_quant::{QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear};

use crate::layers::MatMul;

use super::{
    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, LinearLayerLike,
    LoraConfig, LoraLinearConfig, Merge,
};

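/// A linear layer augmented with LoRA adapters.
///
/// `old` is the wrapped base layer. The A/B adapters are stored either as
/// plain per-adapter `Linear`s (`Either::Left`) or, when every adapter has the
/// same shape, additionally as a pre-stacked tensor (`Either::Right`) so that
/// all adapters can be applied with batched matmuls in `lora_forward`.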
pub struct LoraLinear {
    old: Arc<dyn QuantMethod>,
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    scale_adapters: Vec<f64>,
    layer_n: usize,
    merged: bool,
    adapters: HashMap<String, Adapter>,
}

impl LoraLinear {
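    /// Loads the `lora_A`/`lora_B` weights for each configured adapter from
    /// `vb`. If every adapter shares the same (rank, in/out features, alpha,
    /// dropout) and no preloaded adapters are supplied, the A/B matrices are
    /// also stacked (with the per-adapter scale folded into A) to enable the
    /// fused forward path.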
    pub fn new(
        old: &dyn LinearLayerLike,
        linear_config: &LoraLinearConfig,
        config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        layer_n: usize,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let mut a_adapters = Vec::with_capacity(config.len());
        let mut b_adapters = Vec::with_capacity(config.len());
        let mut scale_adapters = Vec::with_capacity(config.len());
        let a_vb = vb.pp("lora_A".to_string());
        let b_vb = vb.pp("lora_B".to_string());
        let mut state = None;
        let mut all_same = true;
        let mut adapters = HashMap::new();
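        // Load each adapter pair and track whether every adapter shares the
        // same (rank, in_features, out_features, alpha, dropout); only then
        // can the adapters be stacked into a single tensor below.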
        for ((name_id, adapter_name), cfg) in config.iter() {
            let a_pp = a_vb.pp(name_id);
            let b_pp = b_vb.pp(name_id);
            let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
            a_adapters.push(adapter.a.clone());
            b_adapters.push(adapter.b.clone());
            scale_adapters.push(adapter.scale);
            if state.is_some_and(|x| {
                x == (
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                )
            }) || state.is_none()
            {
                state = Some((
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                ));
            } else {
                all_same = false;
            }
            adapters.insert(adapter_name.clone(), adapter);
        }

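        // Preloaded adapters may have different shapes, so their presence
        // disables the stacked fast path.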
        if let Some(preload_adapters) = preload_adapters {
            all_same = false;
            for (name, (vb, cfg)) in preload_adapters {
                let a_vb = vb.set_prefix(a_vb.prefix());
                let b_vb = vb.set_prefix(b_vb.prefix());
                let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
                adapters.insert(name.clone(), adapter);
            }
        }

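        // Fast path setup: stack all A and B weights along a new leading
        // dimension and fold each adapter's scale into its A matrix, so the
        // forward pass can apply every adapter with broadcast matmuls.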
        if all_same {
            let a_adapters_stack = Tensor::cat(
                &a_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let b_adapters_stack = Tensor::cat(
                &b_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let scale_adapters_t = Tensor::from_vec(
                scale_adapters.clone(),
                (scale_adapters.len(), 1, 1),
                a_adapters_stack.device(),
            )?
            .to_dtype(a_adapters_stack.dtype())?;
            let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
            Ok(LoraLinear {
                old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(old.weight().clone(), old.bias().cloned()),
                ))?),
                a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
                b_adapters: Either::Right((b_adapters_stack, b_adapters)),
                scale_adapters,
                layer_n,
                merged: false,
                adapters,
            })
        } else {
            Ok(LoraLinear {
                old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(old.weight().clone(), old.bias().cloned()),
                ))?),
                a_adapters: Either::Left(a_adapters),
                b_adapters: Either::Left(b_adapters),
                scale_adapters,
                layer_n,
                merged: false,
                adapters,
            })
        }
    }
}

impl Merge for LoraLinear {
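    /// Returns the weight delta contributed by one adapter:
    /// `delta_w = scale * (B @ A)`.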
    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
        match (&self.a_adapters, &self.b_adapters) {
            (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
                let w_a = a[adapter].weight();
                let w_b = b[adapter].weight();

                MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
            }
            _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
        }
    }

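    /// Sums the deltas of all adapters and folds them into the base layer via
    /// `add_delta_w`; afterwards `lora_forward` returns the base output as-is.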
    fn merge_weights(&mut self) -> Result<()> {
        let mut w_base_layer: Option<Tensor> = None;
        for adapter in 0..self.scale_adapters.len() {
            if let Some(w_base_layer) = &mut w_base_layer {
                *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
            } else {
                w_base_layer = Some(self.get_delta_weight(adapter)?)
            }
        }
        self.old
            .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
        self.merged = true;
        Ok(())
    }
}

impl LinearLayerLike for LoraLinear {
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        &mut self.old
    }
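    // The base layer is only reachable through `quant_inner`; direct
    // weight/bias access is not supported for the wrapped layer.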
    fn bias(&self) -> Option<&Tensor> {
        unreachable!()
    }
    fn weight(&self) -> &Tensor {
        unreachable!()
    }
    fn quantized_act_type(&self) -> Option<DType> {
        self.old.quantized_act_type()
    }
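    /// Base forward pass plus the LoRA contribution. Returns the base output
    /// unchanged when the adapters have been merged or during a zero-weight
    /// scaling pass; otherwise dispatches to either the per-adapter loop or
    /// the stacked batched path below.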
    fn lora_forward(
        &self,
        input: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let mut result = self.old.forward(input)?;

        if self.merged {
            return Ok(result);
        }

        if is_scaling_pass.is_some_and(|x| x == 0.) {
            return Ok(result);
        }

        let scalings =
            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
        if self.a_adapters.is_left()
            || scalings
                .as_ref()
                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
        {
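            // Per-adapter path: adapters differ in shape, or the scalings vary
            // across the sequence, so each (A, B) pair is applied separately.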
            let a_adapters = if self.a_adapters.is_right() {
                self.a_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.a_adapters.as_ref().unwrap_left().clone()
            };
            let b_adapters = if self.b_adapters.is_right() {
                self.b_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.b_adapters.as_ref().unwrap_left().clone()
            };
            // No fan_in_fan_out, so no weight.transpose(0, 1)
            for (i, (adapter_a, (adapter_b, adapter_scale))) in
                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
            {
                let input_new = input.to_dtype(adapter_a.weight().dtype())?;
                let input_new = if let Some(scalings) = &scalings {
                    apply_scalings_to_x(input_new, scalings, i)?
                } else {
                    input_new
                };

                let res = adapter_b
                    .forward(&adapter_a.forward(&input_new)?)?
                    .mul(*adapter_scale)?
                    .mul(global_scaling_weight)?;
                result = (result + res)?;
            }
            Ok(result)
        } else {
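            // Stacked path: all adapters share a shape, so the pre-stacked A/B
            // tensors (scale already folded into A) apply every adapter with
            // two broadcast matmuls; the per-adapter outputs are summed at the
            // end.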
            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
            let adapter_scales = &self.scale_adapters;
            let n_adapters = adapter_scales.len();
            let adapter_a = if let Some(scalings) = scalings.as_ref() {
                let scalings = scalings
                    .squeeze(0)?
                    .squeeze(0)?
                    .unsqueeze(1)?
                    .unsqueeze(1)?;
                adapter_a
                    .broadcast_mul(&scalings)?
                    .mul(global_scaling_weight)?
            } else {
                adapter_a.clone().mul(global_scaling_weight)?
            };

            let (b, s, h) = input.dims3()?;
            let input = input.reshape((b * s, h))?;
            let out = adapter_a.broadcast_matmul(&input.t()?)?;
            let out = adapter_b.broadcast_matmul(&out)?;
            let o_h = out.dims()[1];
            let out = out.reshape((n_adapters, b, s, o_h))?;
            let out = out.sum(0)?;
            out + result
        }
    }
    fn is_lora(&self) -> bool {
        !self.adapters.is_empty()
    }
}
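
#[cfg(test)]
mod delta_weight_sketch {
    // A minimal sketch, not part of the original file: it illustrates the
    // delta-weight formula used by `get_delta_weight` (`delta_w = scale * (B @ A)`)
    // directly on candle tensors, assuming the usual LoRA shapes of
    // A: (rank, in_features) and B: (out_features, rank). The module and test
    // names are hypothetical.
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn delta_weight_has_base_weight_shape() -> Result<()> {
        let (in_features, out_features, rank) = (8, 4, 2);
        let dev = Device::Cpu;
        let a = Tensor::ones((rank, in_features), DType::F32, &dev)?;
        let b = Tensor::ones((out_features, rank), DType::F32, &dev)?;
        let scale = 0.5;
        // delta_w = scale * (B @ A) has the same shape as the base weight.
        let delta_w = (b.matmul(&a)? * scale)?;
        assert_eq!(delta_w.dims(), &[out_features, in_features]);
        Ok(())
    }
}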