1use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};
2
3use candle_core::{bail, quantized::QMatMul, DType, Module, Result, Tensor};
4use candle_nn::Linear;
5use either::Either;
6use mistralrs_quant::{
7 GgufMatMul, QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear,
8};
9
10use crate::layers::MatMul;
11
12use super::{
13 apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper,
14 LinearLayerLike, LoraConfig, LoraLinearConfig, Merge, Ordering,
15};
16
/// A linear layer whose (possibly GGUF-quantized) base weights are augmented
/// with LoRA adapters.
///
/// Constructed by [`QLoraLinear::new`]. When every adapter shares the same
/// rank/shape/alpha/dropout, the adapters are additionally pre-stacked into a
/// single tensor (the `Either::Right` variants) for a batched fast path;
/// otherwise only the per-adapter `Linear`s are kept (`Either::Left`).
#[derive(Debug)]
pub struct QLoraLinear {
    /// The frozen base layer; all input passes through this first.
    old: Arc<dyn QuantMethod>,
    /// LoRA `A` matrices. `Right` holds `(stacked tensor, originals)` where the
    /// stacked tensor is already pre-multiplied by each adapter's scale.
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    /// LoRA `B` matrices. `Right` holds `(stacked tensor, originals)`; unlike
    /// `a_adapters`, the stacked `B` tensor is not pre-scaled.
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    /// Per-adapter scale, as produced by `make_adapter`.
    scale_adapters: Vec<f64>,
    /// This layer's index for scaling lookup (from `Ordering::layers`);
    /// `usize::MAX` when this module is not a LoRA target.
    layer_n: usize,
    /// Set once `merge_weights` has folded the adapter deltas into `old`.
    merged: bool,
    /// All known adapters by name, used by `_activate_adapters` to swap the
    /// active set at runtime.
    adapters: HashMap<String, Adapter>,
    /// `None` when this module is not a LoRA target (see `can_load`).
    linear_config: Option<LoraLinearConfig>,
}
28
29impl QLoraLinear {
31 #[allow(clippy::too_many_arguments)]
32 pub fn new(
33 old: QMatMul,
34 linear_config: &LoraLinearConfig,
35 config: &[((String, String), LoraConfig)],
36 vb: &ShardedVarBuilder,
37 ordering: &Ordering,
38 prefix: String,
39 count: &mut usize,
40 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
41 ) -> Result<Self> {
42 let target_modules = &config.first().map(|c| &c.1.target_modules);
43 for (_, cfg) in config {
44 if target_modules
45 .as_ref()
46 .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
47 {
48 candle_core::bail!("Expected all target modules to be the same.");
49 }
50 }
51
52 let old: Arc<dyn QuantMethod> = match old {
53 QMatMul::QTensor(q) => Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
54 q_weight: q,
55 b: None,
56 })?),
57 QMatMul::TensorF16(t) | QMatMul::Tensor(t) => Arc::new(UnquantLinear::new(
58 QuantMethodConfig::Unquantized(Linear::new(t, None)),
59 )?),
60 };
61
62 let module = prefix.split('.').next_back().unwrap();
63 if target_modules.is_some_and(|target_modules| !target_modules.contains(module)) {
64 return Ok(Self {
65 old,
66 a_adapters: Either::Left(vec![]),
67 b_adapters: Either::Left(vec![]),
68 scale_adapters: vec![],
69 layer_n: usize::MAX,
70 merged: false,
71 adapters: HashMap::default(),
72 linear_config: None,
73 });
74 }
75
76 *count += 1;
77
78 let mut a_adapters = Vec::with_capacity(config.len());
79 let mut b_adapters = Vec::with_capacity(config.len());
80 let mut scale_adapters = Vec::with_capacity(config.len());
81 let vb = vb.pp(prefix.clone());
82 let a_vb = vb.pp("lora_A".to_string());
83 let b_vb = vb.pp("lora_B".to_string());
84 let mut state = None;
85 let mut all_same = true;
86 let mut adapters = HashMap::new();
87 for ((name_id, adapter_name), cfg) in config.iter() {
88 let a_pp = a_vb.pp(name_id);
89 let b_pp = b_vb.pp(name_id);
90 let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
91 a_adapters.push(adapter.a.clone());
92 b_adapters.push(adapter.b.clone());
93 scale_adapters.push(adapter.scale);
94 if state.is_some_and(|x| {
95 x == (
96 cfg.rank,
97 linear_config.in_features,
98 linear_config.out_features,
99 cfg.alpha,
100 cfg.dropout,
101 )
102 }) || state.is_none()
103 {
104 state = Some((
105 cfg.rank,
106 linear_config.in_features,
107 linear_config.out_features,
108 cfg.alpha,
109 cfg.dropout,
110 ));
111 } else {
112 all_same = false;
113 }
114 adapters.insert(adapter_name.clone(), adapter);
115 }
116
117 if let Some(preload_adapters) = preload_adapters {
118 all_same = false;
119 for (name, (vb, cfg)) in preload_adapters {
120 let a_vb = vb.set_prefix(a_vb.prefix());
121 let b_vb = vb.set_prefix(b_vb.prefix());
122 let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
123 adapters.insert(name.clone(), adapter);
124 }
125 }
126
127 let layer = if let Some(ref layers) = ordering.layers {
128 *layers.get(&prefix).unwrap()
129 } else {
130 0
131 };
132
133 if all_same {
134 let a_adapters_stack = Tensor::cat(
135 &a_adapters
136 .iter()
137 .map(|x| x.weight().unsqueeze(0))
138 .collect::<Result<Vec<_>>>()?,
139 0,
140 )?;
141 let b_adapters_stack = Tensor::cat(
142 &b_adapters
143 .iter()
144 .map(|x| x.weight().unsqueeze(0))
145 .collect::<Result<Vec<_>>>()?,
146 0,
147 )?;
148 let scale_adapters_t = Tensor::from_vec(
149 scale_adapters.clone(),
150 (scale_adapters.len(), 1, 1),
151 a_adapters_stack.device(),
152 )?
153 .to_dtype(a_adapters_stack.dtype())?;
154 let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
155 Ok(QLoraLinear {
156 old,
157 a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
158 b_adapters: Either::Right((b_adapters_stack.clone(), b_adapters)),
159 scale_adapters,
160 layer_n: layer,
161 merged: false,
162 adapters,
163 linear_config: Some(linear_config.clone()),
164 })
165 } else {
166 Ok(QLoraLinear {
167 old,
168 a_adapters: Either::Left(a_adapters),
169 b_adapters: Either::Left(b_adapters),
170 scale_adapters,
171 layer_n: layer,
172 merged: false,
173 adapters,
174 linear_config: Some(linear_config.clone()),
175 })
176 }
177 }
178}
179
180impl AdapterSwapper for QLoraLinear {
181 fn _activate_adapters(&mut self, adapter_names: &[String]) -> Result<()> {
182 match (
183 &mut self.a_adapters,
184 &mut self.b_adapters,
185 &mut self.scale_adapters,
186 ) {
187 (Either::Left(a), Either::Left(b), s) => {
188 a.clear();
189 b.clear();
190 s.clear();
191 for adapter_name in adapter_names {
192 let Adapter {
193 a: a_w,
194 b: b_w,
195 scale,
196 } = match self.adapters.get(adapter_name) {
197 Some(a) => a,
198 None => bail!("Cannot load adapter `{adapter_name}`."),
199 };
200 a.push(a_w.clone());
201 b.push(b_w.clone());
202 s.push(*scale);
203 }
204 }
205 _ => unreachable!("Adapters should not be stacked if new ones are being activated."),
206 }
207 Ok(())
208 }
209 fn can_load(&self) -> bool {
210 self.linear_config.is_some()
211 }
212}
213
214impl Merge for QLoraLinear {
215 fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
216 match (&self.a_adapters, &self.b_adapters) {
217 (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
218 let w_a = a[adapter].weight();
219 let w_b = b[adapter].weight();
220
221 MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
222 }
223 _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
224 }
225 }
226
227 fn merge_weights(&mut self) -> Result<()> {
228 let mut w_base_layer: Option<Tensor> = None;
229 for adapter in 0..self.scale_adapters.len() {
230 if let Some(w_base_layer) = &mut w_base_layer {
231 *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
232 } else {
233 w_base_layer = Some(self.get_delta_weight(adapter)?)
234 }
235 }
236 self.old
237 .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
238 self.merged = true;
239 Ok(())
240 }
241}
242
impl LinearLayerLike for QLoraLinear {
    // Mutable handle to the wrapped base layer.
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        &mut self.old
    }
    // No bias is exposed: `new` wraps the base weight without one.
    fn bias(&self) -> Option<&Tensor> {
        None
    }
    // Not supported (the base may be quantized); panics if called.
    fn weight(&self) -> &Tensor {
        unimplemented!()
    }
    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }
    /// Forward pass: base-layer output plus the scaled LoRA adapter deltas.
    ///
    /// `scalings`, when given, gates each adapter's contribution per token;
    /// `global_scaling_weight` multiplies every adapter's output;
    /// `is_scaling_pass == Some(0.)` requests the base output only.
    fn lora_forward(
        &self,
        input: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let mut result = self.old.forward(input)?;
        // Adapters already folded into the base weights: nothing more to add.
        if self.merged {
            return Ok(result);
        }

        // No adapters on this layer, or a zero-weight scaling pass: base only.
        if self
            .a_adapters
            .as_ref()
            .left()
            .is_some_and(|x| x.is_empty())
            || (is_scaling_pass.is_some_and(|x| x == 0.))
        {
            return Ok(result);
        }

        // NOTE(review): this unwrap (and dims3 below) panics rather than
        // erroring if the scalings tensor is malformed for this layer.
        let scalings =
            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
        // Per-adapter path: taken when the adapters are unstacked, or when the
        // scalings vary along dim 1 so a single broadcast cannot express them.
        if self.a_adapters.is_left()
            || scalings
                .as_ref()
                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
        {
            let a_adapters = if self.a_adapters.is_right() {
                self.a_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.a_adapters.as_ref().unwrap_left().clone()
            };
            let b_adapters = if self.b_adapters.is_right() {
                self.b_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.b_adapters.as_ref().unwrap_left().clone()
            };
            // result += B_i(A_i(scaled_input)) * scale_i * global for each adapter i.
            for (i, (adapter_a, (adapter_b, adapter_scale))) in
                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
            {
                let input_new = if let Some(scalings) = &scalings {
                    apply_scalings_to_x(input.clone(), scalings, i)?
                } else {
                    input.clone()
                };

                let res = adapter_b
                    .forward(&adapter_a.forward(&input_new)?)?
                    .mul(*adapter_scale)?
                    .mul(global_scaling_weight)?;
                result = (result + res)?;
            }
            Ok(result)
        } else {
            // Batched path: A/B were pre-stacked in `new`, with per-adapter
            // scales already folded into the stacked A.
            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
            let adapter_scales = &self.scale_adapters;
            let n_adapters = adapter_scales.len();
            let adapter_a = if let Some(scalings) = scalings.as_ref() {
                // Collapse the scalings to (n_adapters, 1, 1) so they broadcast
                // over each stacked A matrix. The two squeeze(0) calls assume
                // the leading batch/seq dims are both 1 here — TODO confirm.
                let scalings = scalings
                    .squeeze(0)?
                    .squeeze(0)?
                    .unsqueeze(1)?
                    .unsqueeze(1)?;
                adapter_a
                    .broadcast_mul(&scalings)?
                    .mul(global_scaling_weight)?
            } else {
                adapter_a.clone().mul(global_scaling_weight)?
            };

            // Flatten tokens (b, s, h) -> (b*s, h), apply both stacked matmuls,
            // then un-flatten to (n_adapters, b, s, o_h) and sum over adapters.
            let (b, s, h) = input.dims3()?;
            let input = input.reshape((b * s, h))?;
            let out = adapter_a.broadcast_matmul(&input.t()?)?;
            let out = adapter_b.broadcast_matmul(&out)?;
            let o_h = out.dims()[1];
            let out = out.reshape((n_adapters, b, s, o_h))?;
            let out = out.sum(0)?;
            out + result
        }
    }
    // True iff at least one adapter was loaded for this layer.
    fn is_lora(&self) -> bool {
        !self.adapters.is_empty()
    }
}