#![allow(clippy::cast_precision_loss)]

use std::{collections::HashSet, fmt::Debug, sync::Arc};

use candle_core::{quantized::QTensor, DType, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module};
use loralinear::LoraLinear;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
pub use qloralinear::QLoraLinear;
use serde::Deserialize;

mod loralinear;
mod qloralinear;

use std::collections::HashMap;

use crate::layers;

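/// An adapter to preload: the adapter's name and the model ID to load its weights from.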
#[derive(Clone, Debug, Deserialize)]
pub struct PreloadAdapter {
    pub name: String,
    pub adapter_model_id: String,
}

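/// Adapter ordering information, typically deserialized from an ordering file: the
/// order in which adapters are activated, an optional map from layer name to layer
/// index, the base model ID, and any adapters to preload.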
#[derive(Clone, Debug, Deserialize)]
pub struct Ordering {
    #[serde(rename = "order")]
    pub adapters: Option<Vec<String>>,
    pub layers: Option<HashMap<String, usize>>,
    pub base_model_id: String,
    pub preload_adapters: Option<Vec<PreloadAdapter>>,
}

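/// Input and output dimensions of the linear layer that a LoRA adapter attaches to.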
#[derive(Clone, Debug)]
pub struct LoraLinearConfig {
    in_features: usize,
    out_features: usize,
}

impl LoraLinearConfig {
    pub fn new(in_features: usize, out_features: usize) -> Self {
        LoraLinearConfig {
            in_features,
            out_features,
        }
    }
}

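/// Per-adapter LoRA hyperparameters, deserialized using the usual PEFT field names:
/// rank (`r`), `lora_alpha`, `lora_dropout`, and the set of target module names.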
#[derive(Clone, Debug, Deserialize)]
pub struct LoraConfig {
    #[serde(rename = "r")]
    rank: usize,
    #[serde(rename = "lora_alpha")]
    alpha: f64,
    #[serde(rename = "lora_dropout")]
    dropout: Option<f32>,
    target_modules: HashSet<String>,
}

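/// Scale `x` by the scalings of the selected adapter, broadcasting the scaling over
/// the last dimension of `x`.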
fn apply_scalings_to_x(x: Tensor, scalings_layer: &Tensor, adapter: usize) -> Result<Tensor> {
    let scalings = scalings_layer.i((.., .., adapter))?.unsqueeze(D::Minus1)?;
    let res = x.broadcast_mul(&scalings)?;
    Ok(res)
}

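/// A single low-rank adapter: the down-projection `a` (rank x in_features), the
/// up-projection `b` (out_features x rank), and its scale.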
#[derive(Debug)]
struct Adapter {
    a: Linear,
    b: Linear,
    scale: f64,
}

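/// Load one adapter's A and B weight matrices from the given var builders and set
/// its scale to `alpha / rank` (falling back to 1.0 when the rank is zero).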
fn make_adapter(
    a_vb: ShardedVarBuilder,
    b_vb: ShardedVarBuilder,
    cfg: &LoraConfig,
    linear_cfg: &LoraLinearConfig,
) -> Result<Adapter> {
    assert!(a_vb.contains_tensor("weight"));
    let a = a_vb.get((cfg.rank, linear_cfg.in_features), "weight")?;
    assert!(b_vb.contains_tensor("weight"));
    let b = b_vb.get((linear_cfg.out_features, cfg.rank), "weight")?;
    let a = Linear::new(a, None);
    let b = Linear::new(b, None);
    let scale = if cfg.rank > 0 {
        cfg.alpha / cfg.rank as f64
    } else {
        1.0
    };
    Ok(Adapter { a, b, scale })
}

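/// Common interface over plain, LoRA, and QLoRA linear layers: access to the base
/// weight and bias, quantization details, and a LoRA-aware forward pass that takes
/// optional per-layer scalings and a global scaling weight.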
pub trait LinearLayerLike: Merge {
    fn quantized_act_type(&self) -> Option<DType>;
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod>;
    fn is_lora(&self) -> bool;
    fn weight(&self) -> &Tensor;
    fn bias(&self) -> Option<&Tensor>;
    fn lora_forward(
        &self,
        x: &Tensor,
        scalings_layer: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor>;
}

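/// Merging of adapter weights into the base weight: `get_delta_weight` returns the
/// weight delta contributed by a single adapter, and `merge_weights` folds the
/// deltas into the underlying layer.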
pub trait Merge {
    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor>;
    fn merge_weights(&mut self) -> Result<()>;
}

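// A plain `Linear` carries no adapters, so merging is a no-op and a delta weight is
// never requested.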
impl Merge for Linear {
    fn merge_weights(&mut self) -> Result<()> {
        Ok(())
    }
    fn get_delta_weight(&self, _adapter: usize) -> Result<Tensor> {
        unreachable!()
    }
}

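// Pass-through implementation for a plain `Linear`: `lora_forward` ignores the
// scalings and simply runs the base forward pass.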
impl LinearLayerLike for Linear {
    fn bias(&self) -> Option<&Tensor> {
        self.bias()
    }
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        unimplemented!("Linear layer has no reasonable quant inner!")
    }
    fn weight(&self) -> &Tensor {
        self.weight()
    }
    fn lora_forward(
        &self,
        x: &Tensor,
        _scalings_layer: Option<Tensor>,
        _global_scaling_weight: f64,
        _is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        self.forward(x)
    }
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }
    fn is_lora(&self) -> bool {
        false
    }
}

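/// Create a linear layer (with bias) from `base_vb`. If the module named by the
/// prefix of `vb` is one of the configured LoRA target modules, wrap it in a
/// [`LoraLinear`] (using the layer index from the ordering) and increment `count`;
/// otherwise return the plain layer.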
#[allow(clippy::too_many_arguments)]
pub fn linear(
    d1: usize,
    d2: usize,
    base_vb: ShardedVarBuilder,
    vb: ShardedVarBuilder,
    lora_config: &[((String, String), LoraConfig)],
    count: &mut usize,
    ord: &Ordering,
    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
    let prefix = vb.prefix();
    let module = prefix.split('.').next_back().unwrap();

    let linear_config = LoraLinearConfig::new(d1, d2);
    let inner = layers::linear(d1, d2, base_vb.clone())?;

    // All adapters must agree on which modules are LoRA targets.
    let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
    for (_, cfg) in lora_config {
        if target_modules
            .as_ref()
            .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
        {
            candle_core::bail!("Expected all target modules to be the same.");
        }
    }

    // Modules that are not LoRA targets stay as plain linear layers.
    if !target_modules
        .as_ref()
        .is_some_and(|target_modules| target_modules.contains(module))
    {
        return Ok(Arc::new(inner));
    }
    // Resolve this module's layer index from the ordering's layer map (0 if no map is given).
    let name = prefix.split("lora_A").last().unwrap();
    let layer = if let Some(ref layers) = ord.layers {
        *layers.get(name).unwrap()
    } else {
        0
    };

    let lorainner = LoraLinear::new(
        &inner,
        &linear_config,
        lora_config,
        &vb,
        layer,
        preload_adapters,
    )?;
    *count += 1;
    Ok(Arc::new(lorainner))
}

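/// Identical to [`linear`], except that the base layer is created without a bias.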
#[allow(clippy::too_many_arguments)]
pub fn linear_no_bias(
    d1: usize,
    d2: usize,
    base_vb: ShardedVarBuilder,
    vb: ShardedVarBuilder,
    lora_config: &[((String, String), LoraConfig)],
    count: &mut usize,
    ord: &Ordering,
    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
    let prefix = vb.prefix();
    let module = prefix.split('.').next_back().unwrap();

    let linear_config = LoraLinearConfig::new(d1, d2);
    let inner = layers::linear_no_bias(d1, d2, base_vb.clone())?;

    let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
    for (_, cfg) in lora_config {
        if target_modules
            .as_ref()
            .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
        {
            candle_core::bail!("Expected all target modules to be the same.");
        }
    }

    if !target_modules
        .as_ref()
        .is_some_and(|target_modules| target_modules.contains(module))
    {
        return Ok(Arc::new(inner));
    }
    let name = prefix.split("lora_A").last().unwrap();
    let layer = if let Some(ref layers) = ord.layers {
        *layers.get(name).unwrap()
    } else {
        0
    };

    let lorainner = LoraLinear::new(
        &inner,
        &linear_config,
        lora_config,
        &vb,
        layer,
        preload_adapters,
    )?;
    *count += 1;
    Ok(Arc::new(lorainner))
}

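/// Select the slice of the scalings tensor that belongs to the given layer.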
fn get_maybe_topk_scalings(scalings: Tensor, layer: usize) -> Result<Tensor> {
    scalings.i((.., .., layer, ..))
}

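/// Dispatch to [`linear`] or [`linear_no_bias`] depending on `bias`.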
#[allow(clippy::too_many_arguments)]
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    base_vb: ShardedVarBuilder,
    vb: ShardedVarBuilder,
    lora_config: &[((String, String), LoraConfig)],
    count: &mut usize,
    ord: &Ordering,
    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
    if bias {
        linear(
            in_dim,
            out_dim,
            base_vb,
            vb,
            lora_config,
            count,
            ord,
            preload_adapters,
        )
    } else {
        linear_no_bias(
            in_dim,
            out_dim,
            base_vb,
            vb,
            lora_config,
            count,
            ord,
            preload_adapters,
        )
    }
}

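/// Build a [`LoraLinearConfig`] from a quantized weight tensor, assuming the usual
/// `(out_features, in_features)` weight layout.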
pub fn get_lora_cfg(tensor: &QTensor) -> LoraLinearConfig {
    LoraLinearConfig::new(tensor.shape().dims()[1], tensor.shape().dims()[0])
}