1use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};
2
3use candle_core::{bail, DType, Module, Result, Tensor};
4use candle_nn::Linear;
5use either::Either;
6use mistralrs_quant::{QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear};
7
8use crate::layers::MatMul;
9
10use super::{
11 apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper,
12 LinearLayerLike, LoraConfig, LoraLinearConfig, Merge,
13};
14
/// A linear layer augmented with LoRA (low-rank adaptation) adapters on top of a
/// frozen base layer.
pub struct LoraLinear {
    /// The wrapped base linear layer (possibly quantized) that adapters are added to.
    old: Arc<dyn QuantMethod>,
    /// Per-adapter LoRA `A` matrices. `Right` additionally carries a stacked tensor of all
    /// `A` weights (with per-adapter scales pre-multiplied in) for a batched forward pass,
    /// used when every adapter shares the same rank/shape/hyperparameters.
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    /// Per-adapter LoRA `B` matrices; same `Left`/`Right` convention as `a_adapters`
    /// (the stacked `B` tensor is not pre-scaled).
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    /// Per-adapter scaling factors, parallel to the adapter vectors.
    scale_adapters: Vec<f64>,
    /// Index of this layer in the model; used to select per-layer scalings.
    layer_n: usize,
    /// True once adapter deltas have been merged into `old`; the forward pass then
    /// skips the adapter computation.
    merged: bool,
    /// All loaded adapters by name (including preloaded ones), for activation by name.
    adapters: HashMap<String, Adapter>,
}
24
25impl LoraLinear {
26 pub fn new(
27 old: &dyn LinearLayerLike,
28 linear_config: &LoraLinearConfig,
29 config: &[((String, String), LoraConfig)],
30 vb: &ShardedVarBuilder,
31 layer_n: usize,
32 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
33 ) -> Result<Self> {
34 let mut a_adapters = Vec::with_capacity(config.len());
35 let mut b_adapters = Vec::with_capacity(config.len());
36 let mut scale_adapters = Vec::with_capacity(config.len());
37 let a_vb = vb.pp("lora_A".to_string());
38 let b_vb = vb.pp("lora_B".to_string());
39 let mut state = None;
40 let mut all_same = true;
41 let mut adapters = HashMap::new();
42 for ((name_id, adapter_name), cfg) in config.iter() {
43 let a_pp = a_vb.pp(name_id);
44 let b_pp = b_vb.pp(name_id);
45 let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
46 a_adapters.push(adapter.a.clone());
47 b_adapters.push(adapter.b.clone());
48 scale_adapters.push(adapter.scale);
49 if state.is_some_and(|x| {
50 x == (
51 cfg.rank,
52 linear_config.in_features,
53 linear_config.out_features,
54 cfg.alpha,
55 cfg.dropout,
56 )
57 }) || state.is_none()
58 {
59 state = Some((
60 cfg.rank,
61 linear_config.in_features,
62 linear_config.out_features,
63 cfg.alpha,
64 cfg.dropout,
65 ));
66 } else {
67 all_same = false;
68 }
69 adapters.insert(adapter_name.clone(), adapter);
70 }
71
72 if let Some(preload_adapters) = preload_adapters {
73 all_same = false;
74 for (name, (vb, cfg)) in preload_adapters {
75 let a_vb = vb.set_prefix(a_vb.prefix());
76 let b_vb = vb.set_prefix(b_vb.prefix());
77 let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
78 adapters.insert(name.clone(), adapter);
79 }
80 }
81
82 if all_same {
83 let a_adapters_stack = Tensor::cat(
84 &a_adapters
85 .iter()
86 .map(|x| x.weight().unsqueeze(0))
87 .collect::<Result<Vec<_>>>()?,
88 0,
89 )?;
90 let b_adapters_stack = Tensor::cat(
91 &b_adapters
92 .iter()
93 .map(|x| x.weight().unsqueeze(0))
94 .collect::<Result<Vec<_>>>()?,
95 0,
96 )?;
97 let scale_adapters_t = Tensor::from_vec(
98 scale_adapters.clone(),
99 (scale_adapters.len(), 1, 1),
100 a_adapters_stack.device(),
101 )?
102 .to_dtype(a_adapters_stack.dtype())?;
103 let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
104 Ok(LoraLinear {
105 old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
106 Linear::new(old.weight().clone(), old.bias().cloned()),
107 ))?),
108 a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
109 b_adapters: Either::Right((b_adapters_stack, b_adapters)),
110 scale_adapters,
111 layer_n,
112 merged: false,
113 adapters,
114 })
115 } else {
116 Ok(LoraLinear {
117 old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
118 Linear::new(old.weight().clone(), old.bias().cloned()),
119 ))?),
120 a_adapters: Either::Left(a_adapters),
121 b_adapters: Either::Left(b_adapters),
122 scale_adapters,
123 layer_n,
124 merged: false,
125 adapters,
126 })
127 }
128 }
129}
130
131impl AdapterSwapper for LoraLinear {
132 fn _activate_adapters(&mut self, adapter_names: &[String]) -> Result<()> {
133 match (
134 &mut self.a_adapters,
135 &mut self.b_adapters,
136 &mut self.scale_adapters,
137 ) {
138 (Either::Left(a), Either::Left(b), s) => {
139 a.clear();
140 b.clear();
141 s.clear();
142 for adapter_name in adapter_names {
143 let Adapter {
144 a: a_w,
145 b: b_w,
146 scale,
147 } = match self.adapters.get(adapter_name) {
148 Some(a) => a,
149 None => bail!("Cannot load adapter `{adapter_name}`."),
150 };
151 a.push(a_w.clone());
152 b.push(b_w.clone());
153 s.push(*scale);
154 }
155 }
156 _ => unreachable!("Adapters should not be stacked if new ones are being activated."),
157 }
158 Ok(())
159 }
160 fn can_load(&self) -> bool {
161 true
162 }
163}
164
165impl Merge for LoraLinear {
166 fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
167 match (&self.a_adapters, &self.b_adapters) {
168 (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
169 let w_a = a[adapter].weight();
170 let w_b = b[adapter].weight();
171
172 MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
173 }
174 _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
175 }
176 }
177
178 fn merge_weights(&mut self) -> Result<()> {
179 let mut w_base_layer: Option<Tensor> = None;
180 for adapter in 0..self.scale_adapters.len() {
181 if let Some(w_base_layer) = &mut w_base_layer {
182 *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
183 } else {
184 w_base_layer = Some(self.get_delta_weight(adapter)?)
185 }
186 }
187 self.old
188 .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
189 self.merged = true;
190 Ok(())
191 }
192}
193
impl LinearLayerLike for LoraLinear {
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        &mut self.old
    }
    // The base weight/bias are owned by `self.old`; callers must not ask this wrapper
    // for them directly.
    fn bias(&self) -> Option<&Tensor> {
        unreachable!()
    }
    fn weight(&self) -> &Tensor {
        unreachable!()
    }
    fn quantized_act_type(&self) -> Option<DType> {
        self.old.quantized_act_type()
    }
    /// Forward pass: base-layer output plus the (optionally scaled) LoRA adapter deltas.
    ///
    /// * `scalings` - optional per-adapter scaling tensor; reduced per layer via
    ///   `get_maybe_topk_scalings` using `self.layer_n`.
    /// * `global_scaling_weight` - multiplier applied to every adapter's contribution.
    /// * `is_scaling_pass` - when `Some(0.)`, adapters are skipped entirely and only the
    ///   base output is returned.
    fn lora_forward(
        &self,
        input: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let mut result = self.old.forward(input)?;

        // Adapters already merged into the base weights — nothing more to add.
        if self.merged {
            return Ok(result);
        }

        if is_scaling_pass.is_some_and(|x| x == 0.) {
            return Ok(result);
        }

        // NOTE(review): this `.unwrap()` (and the `dims3().unwrap()` below) panics if the
        // scalings are malformed rather than propagating the error — consider `?`-friendly
        // handling.
        let scalings =
            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
        // Per-adapter loop path: taken when the adapters are unstacked, or when the
        // scalings have a non-unit middle dimension (presumably seq-len > 1 — TODO confirm)
        // so the batched broadcast below would not apply.
        if self.a_adapters.is_left()
            || scalings
                .as_ref()
                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
        {
            let a_adapters = if self.a_adapters.is_right() {
                self.a_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.a_adapters.as_ref().unwrap_left().clone()
            };
            let b_adapters = if self.b_adapters.is_right() {
                self.b_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.b_adapters.as_ref().unwrap_left().clone()
            };
            // Accumulate scale * global * B(A(x)) for each adapter into the base output.
            for (i, (adapter_a, (adapter_b, adapter_scale))) in
                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
            {
                let input_new = input.to_dtype(adapter_a.weight().dtype())?;
                let input_new = if let Some(scalings) = &scalings {
                    apply_scalings_to_x(input_new, scalings, i)?
                } else {
                    input_new
                };

                let res = adapter_b
                    .forward(&adapter_a.forward(&input_new)?)?
                    .mul(*adapter_scale)?
                    .mul(global_scaling_weight)?;
                result = (result + res)?;
            }
            Ok(result)
        } else {
            // Batched path: one broadcast matmul over all adapters at once. The stacked A
            // already has per-adapter scales folded in (see `new`), so only the optional
            // scalings and the global weight are applied here.
            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
            let adapter_scales = &self.scale_adapters;
            let n_adapters = adapter_scales.len();
            let adapter_a = if let Some(scalings) = scalings.as_ref() {
                // Reshape scalings to (n_adapters, 1, 1) so they broadcast over A's rows/cols.
                let scalings = scalings
                    .squeeze(0)?
                    .squeeze(0)?
                    .unsqueeze(1)?
                    .unsqueeze(1)?;
                adapter_a
                    .broadcast_mul(&scalings)?
                    .mul(global_scaling_weight)?
            } else {
                adapter_a.clone().mul(global_scaling_weight)?
            };

            // Flatten (batch, seq, hidden) -> (batch*seq, hidden) for the matmuls.
            let (b, s, h) = input.dims3()?;
            let input = input.reshape((b * s, h))?;
            let out = adapter_a.broadcast_matmul(&input.t()?)?;
            let out = adapter_b.broadcast_matmul(&out)?;
            let o_h = out.dims()[1];
            // NOTE(review): `out` appears to be (n_adapters, out_features, b*s) at this
            // point, yet it is reshaped to (n_adapters, b, s, o_h) without a transpose —
            // the element counts match but the layout looks transposed; verify against the
            // per-adapter loop path.
            let out = out.reshape((n_adapters, b, s, o_h))?;
            // Sum the contributions of all adapters.
            let out = out.sum(0)?;
            out + result
        }
    }
    fn is_lora(&self) -> bool {
        !self.adapters.is_empty()
    }
}