use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc};

use candle_core::{DType, Module, Result, Tensor};
use candle_nn::Linear;
use either::Either;
use mistralrs_quant::{QuantMethod, QuantMethodConfig, ShardedVarBuilder, UnquantLinear};

use crate::layers::MatMul;

use super::{
    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, LinearLayerLike,
    LoraConfig, LoraLinearConfig, Merge,
};

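/// A linear layer augmented with a set of LoRA adapters.
///
/// The A/B adapter weights are held as `Either::Left(Vec<Linear>)` when the
/// adapters have heterogeneous configurations and must be applied one by one,
/// or as `Either::Right((Tensor, Vec<Linear>))` when they all share a shape and
/// can additionally be applied via a single stacked, broadcast matmul.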
pub struct LoraLinear {
    old: Arc<dyn QuantMethod>,
    a_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    b_adapters: Either<Vec<Linear>, (Tensor, Vec<Linear>)>,
    scale_adapters: Vec<f64>,
    layer_n: usize,
    merged: bool,
    adapters: HashMap<String, Adapter>,
}

impl LoraLinear {
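    /// Build a `LoraLinear` around the base layer `old` from the adapters in
    /// `config`. If every adapter shares the same rank, in/out features, alpha,
    /// and dropout, the A and B weights are also stacked (with the per-adapter
    /// scales folded into A) so the forward pass can use one broadcast matmul;
    /// preloaded adapters always force the per-adapter path.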
    pub fn new(
        old: &dyn LinearLayerLike,
        linear_config: &LoraLinearConfig,
        config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        layer_n: usize,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let mut a_adapters = Vec::with_capacity(config.len());
        let mut b_adapters = Vec::with_capacity(config.len());
        let mut scale_adapters = Vec::with_capacity(config.len());
        let a_vb = vb.pp("lora_A".to_string());
        let b_vb = vb.pp("lora_B".to_string());
        let mut state = None;
        let mut all_same = true;
        let mut adapters = HashMap::new();
        for ((name_id, adapter_name), cfg) in config.iter() {
            let a_pp = a_vb.pp(name_id);
            let b_pp = b_vb.pp(name_id);
            let adapter = make_adapter(a_pp, b_pp, cfg, linear_config)?;
            a_adapters.push(adapter.a.clone());
            b_adapters.push(adapter.b.clone());
            scale_adapters.push(adapter.scale);
            if state.is_some_and(|x| {
                x == (
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                )
            }) || state.is_none()
            {
                state = Some((
                    cfg.rank,
                    linear_config.in_features,
                    linear_config.out_features,
                    cfg.alpha,
                    cfg.dropout,
                ));
            } else {
                all_same = false;
            }
            adapters.insert(adapter_name.clone(), adapter);
        }

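        // Adapters preloaded from their own var builders are registered by name only;
        // they are not stacked, so their presence disables the fused path below.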
        if let Some(preload_adapters) = preload_adapters {
            all_same = false;
            for (name, (vb, cfg)) in preload_adapters {
                let a_vb = vb.set_prefix(a_vb.prefix());
                let b_vb = vb.set_prefix(b_vb.prefix());
                let adapter = make_adapter(a_vb, b_vb, cfg, linear_config)?;
                adapters.insert(name.clone(), adapter);
            }
        }

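        // Fast-path setup: when all adapters share (rank, in/out features, alpha,
        // dropout), stack the A and B weights and fold the scales into the stacked A
        // so `lora_forward` can apply every adapter with a single broadcast matmul.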
        if all_same {
            let a_adapters_stack = Tensor::cat(
                &a_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let b_adapters_stack = Tensor::cat(
                &b_adapters
                    .iter()
                    .map(|x| x.weight().unsqueeze(0))
                    .collect::<Result<Vec<_>>>()?,
                0,
            )?;
            let scale_adapters_t = Tensor::from_vec(
                scale_adapters.clone(),
                (scale_adapters.len(), 1, 1),
                a_adapters_stack.device(),
            )?
            .to_dtype(a_adapters_stack.dtype())?;
            let a_adapters_stack = a_adapters_stack.broadcast_mul(&scale_adapters_t)?;
            Ok(LoraLinear {
                old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(old.weight().clone(), old.bias().cloned()),
                ))?),
                a_adapters: Either::Right((a_adapters_stack.clone(), a_adapters)),
                b_adapters: Either::Right((b_adapters_stack, b_adapters)),
                scale_adapters,
                layer_n,
                merged: false,
                adapters,
            })
        } else {
            Ok(LoraLinear {
                old: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(old.weight().clone(), old.bias().cloned()),
                ))?),
                a_adapters: Either::Left(a_adapters),
                b_adapters: Either::Left(b_adapters),
                scale_adapters,
                layer_n,
                merged: false,
                adapters,
            })
        }
    }
}

impl Merge for LoraLinear {
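    /// Weight delta contributed by a single adapter: `scale * (B @ A)`.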
    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor> {
        match (&self.a_adapters, &self.b_adapters) {
            (Either::Left(a), Either::Left(b)) | (Either::Right((_, a)), Either::Right((_, b))) => {
                let w_a = a[adapter].weight();
                let w_b = b[adapter].weight();

                MatMul.matmul(w_b, w_a)? * self.scale_adapters[adapter]
            }
            _ => unreachable!("Both adapters must be Either::Left or Either::Right."),
        }
    }

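    /// Sum the deltas of all adapters and fold them into the base layer's weights;
    /// once merged, `lora_forward` returns the plain base-layer output.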
    fn merge_weights(&mut self) -> Result<()> {
        let mut w_base_layer: Option<Tensor> = None;
        for adapter in 0..self.scale_adapters.len() {
            if let Some(w_base_layer) = &mut w_base_layer {
                *w_base_layer = (&*w_base_layer + &self.get_delta_weight(adapter)?)?;
            } else {
                w_base_layer = Some(self.get_delta_weight(adapter)?)
            }
        }
        self.old
            .add_delta_w(w_base_layer.as_ref().expect("Found no adapters to merge."))?;
        self.merged = true;
        Ok(())
    }
}

impl LinearLayerLike for LoraLinear {
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        &mut self.old
    }
    fn bias(&self) -> Option<&Tensor> {
        unreachable!()
    }
    fn weight(&self) -> &Tensor {
        unreachable!()
    }
    fn quantized_act_type(&self) -> Option<DType> {
        self.old.quantized_act_type()
    }
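    /// Forward pass: the base layer's output plus the scaled LoRA adapter outputs.
    /// Returns the base output unchanged if the adapters were already merged or
    /// during a scaling pass whose scale is zero.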
    fn lora_forward(
        &self,
        input: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let mut result = self.old.forward(input)?;

        if self.merged {
            return Ok(result);
        }

        if is_scaling_pass.is_some_and(|x| x == 0.) {
            return Ok(result);
        }

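        // Reduce any provided scalings to this layer's slice (possibly top-k).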
        let scalings =
            scalings.map(|scalings| get_maybe_topk_scalings(scalings, self.layer_n).unwrap());
        if self.a_adapters.is_left()
            || scalings
                .as_ref()
                .is_some_and(|scalings| scalings.dims3().unwrap().1 != 1)
        {
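            // Per-adapter path: the adapters are heterogeneous, or the scalings have a
            // non-unit second dimension, so apply each A/B pair in turn and accumulate
            // its scaled output into `result`.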
            let a_adapters = if self.a_adapters.is_right() {
                self.a_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.a_adapters.as_ref().unwrap_left().clone()
            };
            let b_adapters = if self.b_adapters.is_right() {
                self.b_adapters.as_ref().unwrap_right().1.clone()
            } else {
                self.b_adapters.as_ref().unwrap_left().clone()
            };
            for (i, (adapter_a, (adapter_b, adapter_scale))) in
                zip(a_adapters, zip(b_adapters, &self.scale_adapters)).enumerate()
            {
                let input_new = input.to_dtype(adapter_a.weight().dtype())?;
                let input_new = if let Some(scalings) = &scalings {
                    apply_scalings_to_x(input_new, scalings, i)?
                } else {
                    input_new
                };

                let res = adapter_b
                    .forward(&adapter_a.forward(&input_new)?)?
                    .mul(*adapter_scale)?
                    .mul(global_scaling_weight)?;
                result = (result + res)?;
            }
            Ok(result)
        } else {
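            // Fused path: apply the stacked A (scales already folded in) and stacked B
            // via broadcast matmuls, then sum the per-adapter outputs over the leading
            // adapter dimension.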
            let adapter_a = &self.a_adapters.as_ref().unwrap_right().0;
            let adapter_b = &self.b_adapters.as_ref().unwrap_right().0;
            let adapter_scales = &self.scale_adapters;
            let n_adapters = adapter_scales.len();
            let adapter_a = if let Some(scalings) = scalings.as_ref() {
                let scalings = scalings
                    .squeeze(0)?
                    .squeeze(0)?
                    .unsqueeze(1)?
                    .unsqueeze(1)?;
                adapter_a
                    .broadcast_mul(&scalings)?
                    .mul(global_scaling_weight)?
            } else {
                adapter_a.clone().mul(global_scaling_weight)?
            };

            let (b, s, h) = input.dims3()?;
            let input = input.reshape((b * s, h))?;
            let out = adapter_a.broadcast_matmul(&input.t()?)?;
            let out = adapter_b.broadcast_matmul(&out)?;
            let o_h = out.dims()[1];
            let out = out.reshape((n_adapters, b, s, o_h))?;
            let out = out.sum(0)?;
            out + result
        }
    }
    fn is_lora(&self) -> bool {
        !self.adapters.is_empty()
    }
}