mistralrs_core/lora/loralinear.rs

1use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};
2
3use candle_core::{bail, DType, Module, Result, Tensor};
4use candle_nn::Linear;
5use either::Either;
6use mistralrs_quant::{QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear};
7
8use crate::layers::MatMul;
9
10use super::{
11    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper,
12    LinearLayerLike, LoraConfig, LoraLinearConfig, Merge,
13};
14
/// A linear layer augmented with a set of LoRA adapters.
///
/// The frozen base layer is kept as `old`; adapter deltas are either applied
/// on the fly in `lora_forward` or folded into `old` via `Merge::merge_weights`.
pub struct LoraLinear {
    // Base linear layer (wrapped as a quant method so `add_delta_w` can merge deltas into it).
    old: Arc<dyn QuantMethod>,
    // `Left`: one `Linear` per adapter, applied in a loop.
    // `Right`: (stacked weight tensor, per-adapter layers) — the stacked A tensor is
    // pre-multiplied by the adapter scales in `new` for the batched fast path.
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    // Same layout as `a_adapters`; the stacked B tensor is NOT pre-scaled.
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    // Per-adapter scaling factors (applied in the looped path and in `get_delta_weight`).
    scale_adapters: Vec<f64>,
    // Index of this layer, used to select this layer's scalings in `get_maybe_topk_scalings`.
    layer_n: usize,
    // Set once adapters have been merged into `old`; `lora_forward` then skips adapter math.
    merged: bool,
    // All known adapters by name, including preloaded ones; used for activation swapping.
    adapters: HashMap<String, Adapter>,
}
24
impl LoraLinear {
    /// Builds a `LoraLinear` from a base layer plus the adapters described in `config`.
    ///
    /// * `old` — the frozen base linear layer whose weight/bias are wrapped.
    /// * `linear_config` — in/out feature sizes shared by all adapters.
    /// * `config` — `((name_id, adapter_name), cfg)` entries; `name_id` selects the
    ///   var-builder sub-prefix, `adapter_name` keys the `adapters` map.
    /// * `vb` — var builder rooted at this layer; adapters load under `lora_A`/`lora_B`.
    /// * `layer_n` — this layer's index (for per-layer scaling lookup at forward time).
    /// * `preload_adapters` — extra adapters loaded from their own var builders; if
    ///   present, the stacked fast path is disabled.
    ///
    /// If every adapter shares the same (rank, in, out, alpha, dropout) tuple and no
    /// preloaded adapters exist, the A/B weights are stacked into single tensors
    /// (`Either::Right`) so the forward pass can batch over adapters; otherwise the
    /// adapters are kept as separate layers (`Either::Left`).
    pub fn new(
        old: &dyn LinearLayerLike,
        linear_config: &LoraLinearConfig,
        config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        layer_n: usize,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let mut a_adapters = Vec::with_capacity(config.len());
        let mut b_adapters = Vec::with_capacity(config.len());
        let mut scale_adapters = Vec::with_capacity(config.len());
        // Standard PEFT weight layout: <prefix>.lora_A.<name_id>, <prefix>.lora_B.<name_id>.
        let a_vb = vb.pp("lora_A".to_string());
        let b_vb = vb.pp("lora_B".to_string());
        // `state` remembers the hyperparameter tuple of the adapters seen so far;
        // `all_same` is cleared as soon as one adapter disagrees, which rules out stacking.
        let mut state = None;
        let mut all_same = true;
        let mut adapters = HashMap::new();
        for ((name_id, adapter_name), cfg) in config.iter() {
            let a_pp = a_vb.pp(name_id);
            let b_pp = b_vb.pp(name_id);
            let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
            a_adapters.push(adapter.a.clone());
            b_adapters.push(adapter.b.clone());
            scale_adapters.push(adapter.scale);
            // First adapter sets `state`; later ones must match it exactly to keep `all_same`.
            if state.is_some_and(|x| {
                x == (
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                )
            }) || state.is_none()
            {
                state = Some((
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                ));
            } else {
                all_same = false;
            }
            adapters.insert(adapter_name.clone(), adapter);
        }

        if let Some(preload_adapters) = preload_adapters {
            // Preloaded adapters may later be swapped in individually, so the
            // stacked representation cannot be used.
            all_same = false;
            for (name, (vb, cfg)) in preload_adapters {
                // Re-anchor each preload var builder at this layer's lora_A/lora_B prefixes.
                let a_vb = vb.set_prefix(a_vb.prefix());
                let b_vb = vb.set_prefix(b_vb.prefix());
                let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
                adapters.insert(name.clone(), adapter);
            }
        }

        if all_same {
            // Fast path: stack all adapter weights along a new leading "adapter" axis.
            let a_adapters_stack = Tensor::cat(
                &a_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let b_adapters_stack = Tensor::cat(
                &b_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            // Shape (n_adapters, 1, 1) so it broadcasts over each adapter's weight matrix.
            let scale_adapters_t = Tensor::from_vec(
                scale_adapters.clone(),
                (scale_adapters.len(), 1, 1),
                a_adapters_stack.device(),
            )?
            .to_dtype(a_adapters_stack.dtype())?;
            // Bake the per-adapter scales into the stacked A tensor once, up front.
            let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
            Ok(LoraLinear {
                old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(old.weight().clone(), old.bias().cloned()),
                ))?),
                a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
                b_adapters: Either::Right((b_adapters_stack, b_adapters)),
                scale_adapters,
                layer_n,
                merged: false,
                adapters,
            })
        } else {
            // Heterogeneous (or swappable) adapters: keep them as separate layers.
            Ok(LoraLinear {
                old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(old.weight().clone(), old.bias().cloned()),
                ))?),
                a_adapters: Either::Left(a_adapters),
                b_adapters: Either::Left(b_adapters),
                scale_adapters,
                layer_n,
                merged: false,
                adapters,
            })
        }
    }
}
130
131impl AdapterSwapper for LoraLinear {
132    fn _activate_adapters(&mut self, adapter_names: &[String]) -> Result<()> {
133        match (
134            &mut self.a_adapters,
135            &mut self.b_adapters,
136            &mut self.scale_adapters,
137        ) {
138            (Either::Left(a), Either::Left(b), s) => {
139                a.clear();
140                b.clear();
141                s.clear();
142                for adapter_name in adapter_names {
143                    let Adapter {
144                        a: a_w,
145                        b: b_w,
146                        scale,
147                    } = match self.adapters.get(adapter_name) {
148                        Some(a) => a,
149                        None => bail!("Cannot load adapter `{adapter_name}`."),
150                    };
151                    a.push(a_w.clone());
152                    b.push(b_w.clone());
153                    s.push(*scale);
154                }
155            }
156            _ => unreachable!("Adapters should not be stacked if new ones are being activated."),
157        }
158        Ok(())
159    }
160    fn can_load(&self) -> bool {
161        true
162    }
163}
164
165impl Merge for LoraLinear {
166    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
167        match (&self.a_adapters, &self.b_adapters) {
168            (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
169                let w_a = a[adapter].weight();
170                let w_b = b[adapter].weight();
171
172                MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
173            }
174            _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
175        }
176    }
177
178    fn merge_weights(&mut self) -> Result<()> {
179        let mut w_base_layer: Option<Tensor> = None;
180        for adapter in 0..self.scale_adapters.len() {
181            if let Some(w_base_layer) = &mut w_base_layer {
182                *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
183            } else {
184                w_base_layer = Some(self.get_delta_weight(adapter)?)
185            }
186        }
187        self.old
188            .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
189        self.merged = true;
190        Ok(())
191    }
192}
193
194impl LinearLayerLike for LoraLinear {
195    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
196        &mut self.old
197    }
198    fn bias(&self) -> Option<&Tensor> {
199        unreachable!()
200    }
201    fn weight(&self) -> &Tensor {
202        unreachable!()
203    }
204    fn quantized_act_type(&self) -> Option<DType> {
205        self.old.quantized_act_type()
206    }
207    fn lora_forward(
208        &self,
209        input: &Tensor,
210        scalings: Option<Tensor>,
211        global_scaling_weight: f64,
212        is_scaling_pass: Option<f64>,
213    ) -> Result<Tensor> {
214        let mut result = self.old.forward(input)?;
215
216        if self.merged {
217            return Ok(result);
218        }
219
220        if is_scaling_pass.is_some_and(|x| x == 0.) {
221            return Ok(result);
222        }
223
224        let scalings =
225            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
226        if self.a_adapters.is_left()
227            || scalings
228                .as_ref()
229                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
230        {
231            let a_adapters = if self.a_adapters.is_right() {
232                self.a_adapters.as_ref().unwrap_right().1.clone()
233            } else {
234                self.a_adapters.as_ref().unwrap_left().clone()
235            };
236            let b_adapters = if self.b_adapters.is_right() {
237                self.b_adapters.as_ref().unwrap_right().1.clone()
238            } else {
239                self.b_adapters.as_ref().unwrap_left().clone()
240            };
241            //No fan_in_fan_out so no weight.transpose(0,1)
242            for (i, (adapter_a, (adapter_b, adapter_scale))) in
243                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
244            {
245                let input_new = input.to_dtype(adapter_a.weight().dtype())?;
246                let input_new = if let Some(scalings) = &scalings {
247                    apply_scalings_to_x(input_new, scalings, i)?
248                } else {
249                    input_new
250                };
251
252                let res = adapter_b
253                    .forward(&adapter_a.forward(&input_new)?)?
254                    .mul(*adapter_scale)?
255                    .mul(global_scaling_weight)?;
256                result = (result + res)?;
257            }
258            Ok(result)
259        } else {
260            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
261            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
262            let adapter_scales = &self.scale_adapters;
263            let n_adapters = adapter_scales.len();
264            let adapter_a = if let Some(scalings) = scalings.as_ref() {
265                let scalings = scalings
266                    .squeeze(0)?
267                    .squeeze(0)?
268                    .unsqueeze(1)?
269                    .unsqueeze(1)?;
270                adapter_a
271                    .broadcast_mul(&scalings)?
272                    .mul(global_scaling_weight)?
273            } else {
274                adapter_a.clone().mul(global_scaling_weight)?
275            };
276
277            let (b, s, h) = input.dims3()?;
278            let input = input.reshape((b * s, h))?;
279            let out = adapter_a.broadcast_matmul(&input.t()?)?;
280            let out = adapter_b.broadcast_matmul(&out)?;
281            let o_h = out.dims()[1];
282            let out = out.reshape((n_adapters, b, s, o_h))?;
283            let out = out.sum(0)?;
284            out + result
285        }
286    }
287    fn is_lora(&self) -> bool {
288        !self.adapters.is_empty()
289    }
290}