mistralrs_core/lora/qloralinear.rs

use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};

use candle_core::{quantized::QMatMul, DType, Module, Result, Tensor};
use candle_nn::Linear;
use either::Either;
use mistralrs_quant::{
    GgufMatMul, QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear,
};

use crate::layers::MatMul;

use super::{
    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, LinearLayerLike,
    LoraConfig, LoraLinearConfig, Merge, Ordering,
};

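/// QLoRA linear layer: a wrapped base layer (`old`, quantized or plain) plus a set of
/// LoRA adapters. `a_adapters`/`b_adapters` are `Either::Left` with one `Linear` per
/// adapter, or `Either::Right` with an additional stacked weight tensor when every
/// adapter shares the same rank, shape, alpha and dropout, enabling a batched matmul.
/// `layer_n` is this layer's index in the adapter `Ordering`, used to select its
/// scalings, and `merged` records whether the adapter deltas have already been folded
/// into the base weights.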
#[derive(Debug)]
pub struct QLoraLinear {
    old: Arc<dyn QuantMethod>,
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    scale_adapters: Vec<f64>,
    layer_n: usize,
    merged: bool,
    adapters: HashMap<String, Adapter>,
}

/// Specialized QLoRA linear layer with no bias.
impl QLoraLinear {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        old: QMatMul,
        linear_config: &LoraLinearConfig,
        config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        prefix: String,
        count: &mut usize,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let target_modules = &config.first().map(|c| &c.1.target_modules);
        for (_, cfg) in config {
            if target_modules
                .as_ref()
                .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
            {
                candle_core::bail!("Expected all target modules to be the same.");
            }
        }

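        // Wrap the base weights as a QuantMethod: GGUF-quantized tensors go through
        // GgufMatMul, plain (possibly f16) tensors through UnquantLinear. Neither path
        // carries a bias.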
        let old: Arc<dyn QuantMethod> = match old {
            QMatMul::QTensor(q) => Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: q,
                b: None,
            })?),
            QMatMul::TensorF16(t) | QMatMul::Tensor(t) => Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(Linear::new(t, None)),
            )?),
        };

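        // If this module is not one of the LoRA target modules, return a plain
        // pass-through wrapper around the base layer with no adapters attached.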
        let module = prefix.split('.').next_back().unwrap();
        if target_modules.is_some_and(|target_modules| !target_modules.contains(module)) {
            return Ok(Self {
                old,
                a_adapters: Either::Left(vec![]),
                b_adapters: Either::Left(vec![]),
                scale_adapters: vec![],
                layer_n: usize::MAX,
                merged: false,
                adapters: HashMap::default(),
            });
        }

        *count += 1;

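        // Load the LoRA A/B weights for every configured adapter under the layer's
        // lora_A/lora_B prefixes, tracking whether all adapters share the same
        // (rank, in_features, out_features, alpha, dropout) so they can be stacked later.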
        let mut a_adapters = Vec::with_capacity(config.len());
        let mut b_adapters = Vec::with_capacity(config.len());
        let mut scale_adapters = Vec::with_capacity(config.len());
        let vb = vb.pp(prefix.clone());
        let a_vb = vb.pp("lora_A".to_string());
        let b_vb = vb.pp("lora_B".to_string());
        let mut state = None;
        let mut all_same = true;
        let mut adapters = HashMap::new();
        for ((name_id, adapter_name), cfg) in config.iter() {
            let a_pp = a_vb.pp(name_id);
            let b_pp = b_vb.pp(name_id);
            let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
            a_adapters.push(adapter.a.clone());
            b_adapters.push(adapter.b.clone());
            scale_adapters.push(adapter.scale);
            if state.is_some_and(|x| {
                x == (
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                )
            }) || state.is_none()
            {
                state = Some((
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                ));
            } else {
                all_same = false;
            }
            adapters.insert(adapter_name.clone(), adapter);
        }

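        // Adapters preloaded by name disable the stacked fast path and are read from
        // their own var builders, re-rooted at this layer's lora_A/lora_B prefixes.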
        if let Some(preload_adapters) = preload_adapters {
            all_same = false;
            for (name, (vb, cfg)) in preload_adapters {
                let a_vb = vb.set_prefix(a_vb.prefix());
                let b_vb = vb.set_prefix(b_vb.prefix());
                let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
                adapters.insert(name.clone(), adapter);
            }
        }

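        // Look up this layer's position in the adapter ordering; it is used later to
        // select the scalings that belong to this layer.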
        let layer = if let Some(ref layers) = ordering.layers {
            *layers.get(&prefix).unwrap()
        } else {
            0
        };

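        // Fast path: when every adapter has identical dimensions, stack the A and B
        // weights into single (n_adapters, ...) tensors and fold each adapter's scale
        // into its A matrix so the forward pass can use one broadcast matmul.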
        if all_same {
            let a_adapters_stack = Tensor::cat(
                &a_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let b_adapters_stack = Tensor::cat(
                &b_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let scale_adapters_t = Tensor::from_vec(
                scale_adapters.clone(),
                (scale_adapters.len(), 1, 1),
                a_adapters_stack.device(),
            )?
            .to_dtype(a_adapters_stack.dtype())?;
            let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
            Ok(QLoraLinear {
                old,
                a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
                b_adapters: Either::Right((b_adapters_stack.clone(), b_adapters)),
                scale_adapters,
                layer_n: layer,
                merged: false,
                adapters,
            })
        } else {
            Ok(QLoraLinear {
                old,
                a_adapters: Either::Left(a_adapters),
                b_adapters: Either::Left(b_adapters),
                scale_adapters,
                layer_n: layer,
                merged: false,
                adapters,
            })
        }
    }
}

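// Merging folds each adapter's delta, scale * (B · A), into the wrapped base layer so
// that subsequent forward passes can skip the LoRA computation entirely.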
impl Merge for QLoraLinear {
    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
        match (&self.a_adapters, &self.b_adapters) {
            (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
                let w_a = a[adapter].weight();
                let w_b = b[adapter].weight();

                MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
            }
            _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
        }
    }

    fn merge_weights(&mut self) -> Result<()> {
        let mut w_base_layer: Option<Tensor> = None;
        for adapter in 0..self.scale_adapters.len() {
            if let Some(w_base_layer) = &mut w_base_layer {
                *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
            } else {
                w_base_layer = Some(self.get_delta_weight(adapter)?)
            }
        }
        self.old
            .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
        self.merged = true;
        Ok(())
    }
}

impl LinearLayerLike for QLoraLinear {
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        &mut self.old
    }
    fn bias(&self) -> Option<&Tensor> {
        None
    }
    fn weight(&self) -> &Tensor {
        unimplemented!()
    }
    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }
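    // Forward pass: run the wrapped base layer first, then add the LoRA contribution
    // unless the adapters have been merged, there are no adapters, or this is a scaling
    // pass with a zero weight.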
    fn lora_forward(
        &self,
        input: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        // No fan_in_fan_out, so no weight.transpose(0, 1)
        let mut result = self.old.forward(input)?;
        if self.merged {
            return Ok(result);
        }

        if self
            .a_adapters
            .as_ref()
            .left()
            .is_some_and(|x| x.is_empty())
            || (is_scaling_pass.is_some_and(|x| x == 0.))
        {
            return Ok(result);
        }

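        // Select this layer's (possibly top-k) scalings and pick a path: the per-adapter
        // loop when the adapters were not stacked or the scalings span more than one
        // position, otherwise the batched stacked-tensor path below.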
        let scalings =
            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
        if self.a_adapters.is_left()
            || scalings
                .as_ref()
                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
        {
            let a_adapters = if self.a_adapters.is_right() {
                self.a_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.a_adapters.as_ref().unwrap_left().clone()
            };
            let b_adapters = if self.b_adapters.is_right() {
                self.b_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.b_adapters.as_ref().unwrap_left().clone()
            };
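            // Per-adapter path: weight the input by each adapter's scaling (if any),
            // apply its A then B projection, and scale by the adapter's LoRA scale and
            // the global scaling weight before accumulating into the base output.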
            for (i, (adapter_a, (adapter_b, adapter_scale))) in
                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
            {
                let input_new = if let Some(scalings) = &scalings {
                    apply_scalings_to_x(input.clone(), scalings, i)?
                } else {
                    input.clone()
                };

                let res = adapter_b
                    .forward(&adapter_a.forward(&input_new)?)?
                    .mul(*adapter_scale)?
                    .mul(global_scaling_weight)?;
                result = (result + res)?;
            }
            Ok(result)
        } else {
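            // Stacked path: the pre-stacked A tensor already carries the per-adapter
            // scales, so apply the scalings and global weight to it, run both batched
            // matmuls over every adapter at once, then sum over the adapter dimension.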
            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
            let adapter_scales = &self.scale_adapters;
            let n_adapters = adapter_scales.len();
            let adapter_a = if let Some(scalings) = scalings.as_ref() {
                let scalings = scalings
                    .squeeze(0)?
                    .squeeze(0)?
                    .unsqueeze(1)?
                    .unsqueeze(1)?;
                adapter_a
                    .broadcast_mul(&scalings)?
                    .mul(global_scaling_weight)?
            } else {
                adapter_a.clone().mul(global_scaling_weight)?
            };

            let (b, s, h) = input.dims3()?;
            let input = input.reshape((b * s, h))?;
            let out = adapter_a.broadcast_matmul(&input.t()?)?;
            let out = adapter_b.broadcast_matmul(&out)?;
            let o_h = out.dims()[1];
            let out = out.reshape((n_adapters, b, s, o_h))?;
            let out = out.sum(0)?;
            out + result
        }
    }
    fn is_lora(&self) -> bool {
        !self.adapters.is_empty()
    }
}