use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};

use candle_core::{quantized::QMatMul, DType, Module, Result, Tensor};
use candle_nn::Linear;
use either::Either;
use mistralrs_quant::{
    GgufMatMul, QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear,
};

use crate::layers::MatMul;

use super::{
    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, LinearLayerLike,
    LoraConfig, LoraLinearConfig, Merge, Ordering,
};

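/// A LoRA-wrapped linear layer whose base weight may be quantized.
///
/// The base layer is held as an `Arc<dyn QuantMethod>`. Adapter A/B matrices are stored either as
/// plain per-adapter `Linear`s (`Either::Left`) or, when every adapter shares the same shape and
/// hyperparameters, additionally as stacked tensors (`Either::Right`) so the forward pass can
/// batch all adapters in a single matmul.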
#[derive(Debug)]
pub struct QLoraLinear {
    old: Arc<dyn QuantMethod>,
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    scale_adapters: Vec<f64>,
    layer_n: usize,
    merged: bool,
    adapters: HashMap<String, Adapter>,
}

impl QLoraLinear {
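    /// Build a `QLoraLinear` from an existing `QMatMul` and the LoRA adapter configs.
    ///
    /// If this module is not one of the configs' `target_modules`, the base layer is returned
    /// with no adapters attached. Otherwise each adapter's A/B weights are loaded from `vb`;
    /// when all adapters share the same (rank, in/out features, alpha, dropout) they are also
    /// stacked into single tensors to enable the batched forward path.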
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        old: QMatMul,
        linear_config: &LoraLinearConfig,
        config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        prefix: String,
        count: &mut usize,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let target_modules = &config.first().map(|c| &c.1.target_modules);
        for (_, cfg) in config {
            if target_modules
                .as_ref()
                .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
            {
                candle_core::bail!("Expected all target modules to be the same.");
            }
        }

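        // Wrap the incoming QMatMul in the QuantMethod abstraction: GGUF-quantized tensors go
        // through GgufMatMul, plain (f16/f32) tensors through an unquantized Linear.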
        let old: Arc<dyn QuantMethod> = match old {
            QMatMul::QTensor(q) => Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: q,
                b: None,
            })?),
            QMatMul::TensorF16(t) | QMatMul::Tensor(t) => Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(Linear::new(t, None)),
            )?),
        };

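        // The last path component of `prefix` names this module (e.g. `q_proj`). If it is not a
        // LoRA target, return a pass-through wrapper with no adapters.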
        let module = prefix.split('.').next_back().unwrap();
        if target_modules.is_some_and(|target_modules| !target_modules.contains(module)) {
            return Ok(Self {
                old,
                a_adapters: Either::Left(vec![]),
                b_adapters: Either::Left(vec![]),
                scale_adapters: vec![],
                layer_n: usize::MAX,
                merged: false,
                adapters: HashMap::default(),
            });
        }

        *count += 1;

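        // Load every configured adapter's A and B matrices from `lora_A`/`lora_B` under this
        // prefix, and track whether all adapters share identical hyperparameters (`all_same`).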
        let mut a_adapters = Vec::with_capacity(config.len());
        let mut b_adapters = Vec::with_capacity(config.len());
        let mut scale_adapters = Vec::with_capacity(config.len());
        let vb = vb.pp(prefix.clone());
        let a_vb = vb.pp("lora_A".to_string());
        let b_vb = vb.pp("lora_B".to_string());
        let mut state = None;
        let mut all_same = true;
        let mut adapters = HashMap::new();
        for ((name_id, adapter_name), cfg) in config.iter() {
            let a_pp = a_vb.pp(name_id);
            let b_pp = b_vb.pp(name_id);
            let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
            a_adapters.push(adapter.a.clone());
            b_adapters.push(adapter.b.clone());
            scale_adapters.push(adapter.scale);
            if state.is_some_and(|x| {
                x == (
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                )
            }) || state.is_none()
            {
                state = Some((
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                ));
            } else {
                all_same = false;
            }
            adapters.insert(adapter_name.clone(), adapter);
        }

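        // Adapters provided for preloading are read from their own var builders (re-prefixed to
        // this layer) and force the per-adapter, non-stacked path.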
        if let Some(preload_adapters) = preload_adapters {
            all_same = false;
            for (name, (vb, cfg)) in preload_adapters {
                let a_vb = vb.set_prefix(a_vb.prefix());
                let b_vb = vb.set_prefix(b_vb.prefix());
                let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
                adapters.insert(name.clone(), adapter);
            }
        }

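        // Resolve this layer's index from the adapter ordering; it is used later to select the
        // per-layer scalings in `lora_forward`.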
        let layer = if let Some(ref layers) = ordering.layers {
            *layers.get(&prefix).unwrap()
        } else {
            0
        };

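        // Fast path: with identical adapter shapes, stack the A and B weights along a new leading
        // adapter dimension and fold each adapter's scale into its A matrix up front.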
        if all_same {
            let a_adapters_stack = Tensor::cat(
                &a_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let b_adapters_stack = Tensor::cat(
                &b_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let scale_adapters_t = Tensor::from_vec(
                scale_adapters.clone(),
                (scale_adapters.len(), 1, 1),
                a_adapters_stack.device(),
            )?
            .to_dtype(a_adapters_stack.dtype())?;
            let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
            Ok(QLoraLinear {
                old,
                a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
                b_adapters: Either::Right((b_adapters_stack.clone(), b_adapters)),
                scale_adapters,
                layer_n: layer,
                merged: false,
                adapters,
            })
        } else {
            Ok(QLoraLinear {
                old,
                a_adapters: Either::Left(a_adapters),
                b_adapters: Either::Left(b_adapters),
                scale_adapters,
                layer_n: layer,
                merged: false,
                adapters,
            })
        }
    }
}

impl Merge for QLoraLinear {
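    /// Delta weight for a single adapter: `scale * (B @ A)`, which has the same shape as the
    /// base weight matrix.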
    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
        match (&self.a_adapters, &self.b_adapters) {
            (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
                let w_a = a[adapter].weight();
                let w_b = b[adapter].weight();

                MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
            }
            _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
        }
    }

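    /// Sum the delta weights of all adapters and fold them into the base layer. Afterwards
    /// `lora_forward` reduces to a plain forward pass through the merged weights.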
    fn merge_weights(&mut self) -> Result<()> {
        let mut w_base_layer: Option<Tensor> = None;
        for adapter in 0..self.scale_adapters.len() {
            if let Some(w_base_layer) = &mut w_base_layer {
                *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
            } else {
                w_base_layer = Some(self.get_delta_weight(adapter)?)
            }
        }
        self.old
            .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
        self.merged = true;
        Ok(())
    }
}

impl LinearLayerLike for QLoraLinear {
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        &mut self.old
    }
    fn bias(&self) -> Option<&Tensor> {
        None
    }
    fn weight(&self) -> &Tensor {
        unimplemented!()
    }
    fn quantized_act_type(&self) -> Option<DType> {
        Some(DType::F32)
    }
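    /// Forward pass through the base layer plus the LoRA adapters.
    ///
    /// Short-circuits if the weights are already merged, there are no adapters, or this is a
    /// scaling pass with a zero scale. Otherwise it either loops over the adapters individually
    /// (applying per-adapter scalings to the input) or, when the adapters are stacked and the
    /// scalings permit it, runs one batched matmul over all adapters and sums the results.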
    fn lora_forward(
        &self,
        input: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let mut result = self.old.forward(input)?;

        if self.merged {
            return Ok(result);
        }

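        // No adapters attached to this module, or a scaling pass with zero scale: nothing to add.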
        if self
            .a_adapters
            .as_ref()
            .left()
            .is_some_and(|x| x.is_empty())
            || (is_scaling_pass.is_some_and(|x| x == 0.))
        {
            return Ok(result);
        }

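        // Reduce the scalings to this layer (and the top-k adapters, if configured). The
        // per-adapter loop is taken when the adapters are not stacked or the scalings have a
        // non-unit second dimension; otherwise the batched stacked path below applies.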
        let scalings =
            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
        if self.a_adapters.is_left()
            || scalings
                .as_ref()
                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
        {
            let a_adapters = if self.a_adapters.is_right() {
                self.a_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.a_adapters.as_ref().unwrap_left().clone()
            };
            let b_adapters = if self.b_adapters.is_right() {
                self.b_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.b_adapters.as_ref().unwrap_left().clone()
            };
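            // Apply each adapter separately: scale the input by that adapter's scaling (when
            // provided), run it through A then B, and accumulate the result, weighted by the
            // adapter scale and the global scaling weight, onto the base output.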
            for (i, (adapter_a, (adapter_b, adapter_scale))) in
                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
            {
                let input_new = if let Some(scalings) = &scalings {
                    apply_scalings_to_x(input.clone(), scalings, i)?
                } else {
                    input.clone()
                };

                let res = adapter_b
                    .forward(&adapter_a.forward(&input_new)?)?
                    .mul(*adapter_scale)?
                    .mul(global_scaling_weight)?;
                result = (result + res)?;
            }
            Ok(result)
        } else {
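            // Batched path: the stacked A tensor already has the per-adapter scales folded in.
            // Apply the (squeezed) scalings and the global weight to A, then run one broadcast
            // matmul through A and B for all adapters at once and sum over the adapter dimension.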
            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
            let adapter_scales = &self.scale_adapters;
            let n_adapters = adapter_scales.len();
            let adapter_a = if let Some(scalings) = scalings.as_ref() {
                let scalings = scalings
                    .squeeze(0)?
                    .squeeze(0)?
                    .unsqueeze(1)?
                    .unsqueeze(1)?;
                adapter_a
                    .broadcast_mul(&scalings)?
                    .mul(global_scaling_weight)?
            } else {
                adapter_a.clone().mul(global_scaling_weight)?
            };

            let (b, s, h) = input.dims3()?;
            let input = input.reshape((b * s, h))?;
            let out = adapter_a.broadcast_matmul(&input.t()?)?;
            let out = adapter_b.broadcast_matmul(&out)?;
            let o_h = out.dims()[1];
            let out = out.reshape((n_adapters, b, s, o_h))?;
            let out = out.sum(0)?;
            out + result
        }
    }
    fn is_lora(&self) -> bool {
        !self.adapters.is_empty()
    }
}