// mistralrs_core/lora/qloralinear.rs
1use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};
2
3use candle_core::{bail, quantized::QMatMul, DType, Module, Result, Tensor};
4use candle_nn::Linear;
5use either::Either;
6use mistralrs_quant::{
7    GgufMatMul, QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear,
8};
9
10use crate::layers::MatMul;
11
12use super::{
13    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper,
14    LinearLayerLike, LoraConfig, LoraLinearConfig, Merge, Ordering,
15};
16
#[derive(Debug)]
pub struct QLoraLinear {
    // The base (possibly quantized) layer; LoRA deltas are added on top of its output.
    old: Arc<dyn QuantMethod>,
    // LoRA `A` matrices, one per active adapter. `Left`: individual `Linear`s only.
    // `Right`: additionally pre-stacked along dim 0 (adapter index) for the batched
    // fast path; the stacked `A` tensor is pre-multiplied by `scale_adapters`.
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    // LoRA `B` matrices, stored the same way as `a_adapters` (the stacked `B` is unscaled).
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    // Per-adapter scaling factors, parallel to the adapter vectors above.
    scale_adapters: Vec<f64>,
    // Layer index used to select this layer's scalings; `usize::MAX` when the module
    // is not targeted by any adapter.
    layer_n: usize,
    // Set by `merge_weights`; when true, `lora_forward` returns the base output only.
    merged: bool,
    // All known adapters by name, for runtime (de)activation via `_activate_adapters`.
    adapters: HashMap<String, Adapter>,
    // `Some` iff adapters were loaded for this module; `None` marks the layer as
    // not loadable (see `can_load`).
    linear_config: Option<LoraLinearConfig>,
}
28
29/// Specialized QLoRA for no bias
30impl QLoraLinear {
31    #[allow(clippy::too_many_arguments)]
32    pub fn new(
33        old: QMatMul,
34        linear_config: &LoraLinearConfig,
35        config: &[((String, String), LoraConfig)],
36        vb: &ShardedVarBuilder,
37        ordering: &Ordering,
38        prefix: String,
39        count: &mut usize,
40        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
41    ) -> Result<Self> {
42        let target_modules = &config.first().map(|c| &c.1.target_modules);
43        for (_, cfg) in config {
44            if target_modules
45                .as_ref()
46                .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
47            {
48                candle_core::bail!("Expected all target modules to be the same.");
49            }
50        }
51
52        let old: Arc<dyn QuantMethod> = match old {
53            QMatMul::QTensor(q) => Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
54                q_weight: q,
55                b: None,
56            })?),
57            QMatMul::TensorF16(t) | QMatMul::Tensor(t) => Arc::new(UnquantLinear::new(
58                QuantMethodConfig::Unquantized(Linear::new(t, None)),
59            )?),
60        };
61
62        let module = prefix.split('.').next_back().unwrap();
63        if target_modules.is_some_and(|target_modules| !target_modules.contains(module)) {
64            return Ok(Self {
65                old,
66                a_adapters: Either::Left(vec![]),
67                b_adapters: Either::Left(vec![]),
68                scale_adapters: vec![],
69                layer_n: usize::MAX,
70                merged: false,
71                adapters: HashMap::default(),
72                linear_config: None,
73            });
74        }
75
76        *count += 1;
77
78        let mut a_adapters = Vec::with_capacity(config.len());
79        let mut b_adapters = Vec::with_capacity(config.len());
80        let mut scale_adapters = Vec::with_capacity(config.len());
81        let vb = vb.pp(prefix.clone());
82        let a_vb = vb.pp("lora_A".to_string());
83        let b_vb = vb.pp("lora_B".to_string());
84        let mut state = None;
85        let mut all_same = true;
86        let mut adapters = HashMap::new();
87        for ((name_id, adapter_name), cfg) in config.iter() {
88            let a_pp = a_vb.pp(name_id);
89            let b_pp = b_vb.pp(name_id);
90            let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
91            a_adapters.push(adapter.a.clone());
92            b_adapters.push(adapter.b.clone());
93            scale_adapters.push(adapter.scale);
94            if state.is_some_and(|x| {
95                x == (
96                    cfg.rank,
97                    linear_config.in_features,
98                    linear_config.out_features,
99                    cfg.alpha,
100                    cfg.dropout,
101                )
102            }) || state.is_none()
103            {
104                state = Some((
105                    cfg.rank,
106                    linear_config.in_features,
107                    linear_config.out_features,
108                    cfg.alpha,
109                    cfg.dropout,
110                ));
111            } else {
112                all_same = false;
113            }
114            adapters.insert(adapter_name.clone(), adapter);
115        }
116
117        if let Some(preload_adapters) = preload_adapters {
118            all_same = false;
119            for (name, (vb, cfg)) in preload_adapters {
120                let a_vb = vb.set_prefix(a_vb.prefix());
121                let b_vb = vb.set_prefix(b_vb.prefix());
122                let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
123                adapters.insert(name.clone(), adapter);
124            }
125        }
126
127        let layer = if let Some(ref layers) = ordering.layers {
128            *layers.get(&prefix).unwrap()
129        } else {
130            0
131        };
132
133        if all_same {
134            let a_adapters_stack = Tensor::cat(
135                &a_adapters
136                    .iter()
137                    .map(|x| x.weight().unsqueeze(0))
138                    .collect::<Result<Vec<_>>>()?,
139                0,
140            )?;
141            let b_adapters_stack = Tensor::cat(
142                &b_adapters
143                    .iter()
144                    .map(|x| x.weight().unsqueeze(0))
145                    .collect::<Result<Vec<_>>>()?,
146                0,
147            )?;
148            let scale_adapters_t = Tensor::from_vec(
149                scale_adapters.clone(),
150                (scale_adapters.len(), 1, 1),
151                a_adapters_stack.device(),
152            )?
153            .to_dtype(a_adapters_stack.dtype())?;
154            let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
155            Ok(QLoraLinear {
156                old,
157                a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
158                b_adapters: Either::Right((b_adapters_stack.clone(), b_adapters)),
159                scale_adapters,
160                layer_n: layer,
161                merged: false,
162                adapters,
163                linear_config: Some(linear_config.clone()),
164            })
165        } else {
166            Ok(QLoraLinear {
167                old,
168                a_adapters: Either::Left(a_adapters),
169                b_adapters: Either::Left(b_adapters),
170                scale_adapters,
171                layer_n: layer,
172                merged: false,
173                adapters,
174                linear_config: Some(linear_config.clone()),
175            })
176        }
177    }
178}
179
180impl AdapterSwapper for QLoraLinear {
181    fn _activate_adapters(&mut self, adapter_names: &[String]) -> Result<()> {
182        match (
183            &mut self.a_adapters,
184            &mut self.b_adapters,
185            &mut self.scale_adapters,
186        ) {
187            (Either::Left(a), Either::Left(b), s) => {
188                a.clear();
189                b.clear();
190                s.clear();
191                for adapter_name in adapter_names {
192                    let Adapter {
193                        a: a_w,
194                        b: b_w,
195                        scale,
196                    } = match self.adapters.get(adapter_name) {
197                        Some(a) => a,
198                        None => bail!("Cannot load adapter `{adapter_name}`."),
199                    };
200                    a.push(a_w.clone());
201                    b.push(b_w.clone());
202                    s.push(*scale);
203                }
204            }
205            _ => unreachable!("Adapters should not be stacked if new ones are being activated."),
206        }
207        Ok(())
208    }
209    fn can_load(&self) -> bool {
210        self.linear_config.is_some()
211    }
212}
213
214impl Merge for QLoraLinear {
215    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
216        match (&self.a_adapters, &self.b_adapters) {
217            (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
218                let w_a = a[adapter].weight();
219                let w_b = b[adapter].weight();
220
221                MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
222            }
223            _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
224        }
225    }
226
227    fn merge_weights(&mut self) -> Result<()> {
228        let mut w_base_layer: Option<Tensor> = None;
229        for adapter in 0..self.scale_adapters.len() {
230            if let Some(w_base_layer) = &mut w_base_layer {
231                *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
232            } else {
233                w_base_layer = Some(self.get_delta_weight(adapter)?)
234            }
235        }
236        self.old
237            .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
238        self.merged = true;
239        Ok(())
240    }
241}
242
243impl LinearLayerLike for QLoraLinear {
244    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
245        &mut self.old
246    }
247    fn bias(&self) -> Option<&Tensor> {
248        None
249    }
250    fn weight(&self) -> &Tensor {
251        unimplemented!()
252    }
253    fn quantized_act_type(&self) -> Option<DType> {
254        Some(DType::F32)
255    }
256    fn lora_forward(
257        &self,
258        input: &Tensor,
259        scalings: Option<Tensor>,
260        global_scaling_weight: f64,
261        is_scaling_pass: Option<f64>,
262    ) -> Result<Tensor> {
263        //No fan_in_fan_out so no weight.transpose(0,1)
264        let mut result = self.old.forward(input)?;
265        if self.merged {
266            return Ok(result);
267        }
268
269        if self
270            .a_adapters
271            .as_ref()
272            .left()
273            .is_some_and(|x| x.is_empty())
274            || (is_scaling_pass.is_some_and(|x| x == 0.))
275        {
276            return Ok(result);
277        }
278
279        let scalings =
280            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
281        if self.a_adapters.is_left()
282            || scalings
283                .as_ref()
284                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
285        {
286            let a_adapters = if self.a_adapters.is_right() {
287                self.a_adapters.as_ref().unwrap_right().1.clone()
288            } else {
289                self.a_adapters.as_ref().unwrap_left().clone()
290            };
291            let b_adapters = if self.b_adapters.is_right() {
292                self.b_adapters.as_ref().unwrap_right().1.clone()
293            } else {
294                self.b_adapters.as_ref().unwrap_left().clone()
295            };
296            for (i, (adapter_a, (adapter_b, adapter_scale))) in
297                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
298            {
299                let input_new = if let Some(scalings) = &scalings {
300                    apply_scalings_to_x(input.clone(), scalings, i)?
301                } else {
302                    input.clone()
303                };
304
305                let res = adapter_b
306                    .forward(&adapter_a.forward(&input_new)?)?
307                    .mul(*adapter_scale)?
308                    .mul(global_scaling_weight)?;
309                result = (result + res)?;
310            }
311            Ok(result)
312        } else {
313            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
314            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
315            let adapter_scales = &self.scale_adapters;
316            let n_adapters = adapter_scales.len();
317            let adapter_a = if let Some(scalings) = scalings.as_ref() {
318                let scalings = scalings
319                    .squeeze(0)?
320                    .squeeze(0)?
321                    .unsqueeze(1)?
322                    .unsqueeze(1)?;
323                adapter_a
324                    .broadcast_mul(&scalings)?
325                    .mul(global_scaling_weight)?
326            } else {
327                adapter_a.clone().mul(global_scaling_weight)?
328            };
329
330            let (b, s, h) = input.dims3()?;
331            let input = input.reshape((b * s, h))?;
332            let out = adapter_a.broadcast_matmul(&input.t()?)?;
333            let out = adapter_b.broadcast_matmul(&out)?;
334            let o_h = out.dims()[1];
335            let out = out.reshape((n_adapters, b, s, o_h))?;
336            let out = out.sum(0)?;
337            out + result
338        }
339    }
340    fn is_lora(&self) -> bool {
341        !self.adapters.is_empty()
342    }
343}