mistralrs_core/lora/mod.rs

#![allow(clippy::cast_precision_loss)]

use std::{collections::HashSet, fmt::Debug, sync::Arc};

use candle_core::{quantized::QTensor, DType, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module};
use loralinear::LoraLinear;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
pub use qloralinear::QLoraLinear;
use serde::Deserialize;

mod loralinear;
mod qloralinear;

use std::collections::HashMap;

use crate::layers;

#[derive(Clone, Debug, Deserialize)]
/// An adapter to preload: its `name` and the `adapter_model_id` to load it from.
pub struct PreloadAdapter {
    pub name: String,
    pub adapter_model_id: String,
}

#[derive(Clone, Debug, Deserialize)]
/// Adapter model ordering information.
pub struct Ordering {
    #[serde(rename = "order")]
    pub adapters: Option<Vec<String>>,
    pub layers: Option<HashMap<String, usize>>,
    pub base_model_id: String,
    pub preload_adapters: Option<Vec<PreloadAdapter>>,
}

#[derive(Clone, Debug)]
/// Configuration for LoraLinear
pub struct LoraLinearConfig {
    in_features: usize,
    out_features: usize,
}

impl LoraLinearConfig {
    pub fn new(in_features: usize, out_features: usize) -> Self {
        LoraLinearConfig {
            in_features,
            out_features,
        }
    }
}

#[derive(Clone, Debug, Deserialize)]
/// Per-adapter LoRA hyperparameters: rank, alpha, dropout, and the set of
/// target modules, deserialized from the adapter's configuration.
pub struct LoraConfig {
    #[serde(rename = "r")]
    rank: usize,
    #[serde(rename = "lora_alpha")]
    alpha: f64,
    #[serde(rename = "lora_dropout")]
    dropout: Option<f32>,
    target_modules: HashSet<String>,
}

/// Scale `x` by the per-token scaling of the given adapter, broadcasting over
/// the hidden dimension. `scalings_layer` has shape `(batch, seq, n_adapters)`.
fn apply_scalings_to_x(x: Tensor, scalings_layer: &Tensor, adapter: usize) -> Result<Tensor> {
    let scalings = scalings_layer.i((.., .., adapter))?.unsqueeze(D::Minus1)?;
    let res = x.broadcast_mul(&scalings)?;
    Ok(res)
}

#[derive(Debug)]
/// A single LoRA adapter: the low-rank `A` and `B` projections plus the
/// scale (`alpha / rank`) applied to their product.
struct Adapter {
    a: Linear,
    b: Linear,
    scale: f64,
}

/// Load a single adapter's `A` and `B` weights from the given var builders and
/// compute its scale, `alpha / rank` (or `1.0` for a zero rank).
fn make_adapter(
    a_vb: ShardedVarBuilder,
    b_vb: ShardedVarBuilder,
    cfg: &LoraConfig,
    linear_cfg: &LoraLinearConfig,
) -> Result<Adapter> {
    assert!(a_vb.contains_tensor("weight"));
    let a = a_vb.get((cfg.rank, linear_cfg.in_features), "weight")?;
    assert!(b_vb.contains_tensor("weight"));
    let b = b_vb.get((linear_cfg.out_features, cfg.rank), "weight")?;
    let a = Linear::new(a, None);
    let b = Linear::new(b, None);
    let scale = if cfg.rank > 0 {
        cfg.alpha / cfg.rank as f64
    } else {
        1.0
    };
    Ok(Adapter { a, b, scale })
}

/// Any layer that is linear-like.
pub trait LinearLayerLike: Merge {
    fn quantized_act_type(&self) -> Option<DType>;
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod>;
    fn is_lora(&self) -> bool;
    fn weight(&self) -> &Tensor;
    fn bias(&self) -> Option<&Tensor>;
    fn lora_forward(
        &self,
        x: &Tensor,
        scalings_layer: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor>;
}

pub trait Merge {
    /// Get the delta weight of the LoRA layer. This is meant to be an internal method.
    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor>;
    /// Merge the LoRA weights.
    fn merge_weights(&mut self) -> Result<()>;
}

impl Merge for Linear {
    fn merge_weights(&mut self) -> Result<()> {
        // A plain `Linear` has no LoRA weights, so merging is a no-op.
        Ok(())
    }
    fn get_delta_weight(&self, _adapter: usize) -> Result<Tensor> {
        unreachable!()
    }
}

impl LinearLayerLike for Linear {
    fn bias(&self) -> Option<&Tensor> {
        self.bias()
    }
    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
        unimplemented!("Linear layer has no reasonable quant inner!")
    }
    fn weight(&self) -> &Tensor {
        self.weight()
    }
    fn lora_forward(
        &self,
        x: &Tensor,
        _scalings_layer: Option<Tensor>,
        _global_scaling_weight: f64,
        _is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        self.forward(x)
    }
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }
    fn is_lora(&self) -> bool {
        false
    }
}

#[allow(clippy::too_many_arguments)]
/// Build a linear layer with bias, wrapping it in a `LoraLinear` when the
/// current module is one of the configured LoRA target modules.
pub fn linear(
    d1: usize,
    d2: usize,
    base_vb: ShardedVarBuilder,
    vb: ShardedVarBuilder,
    lora_config: &[((String, String), LoraConfig)],
    count: &mut usize,
    ord: &Ordering,
    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
    let prefix = vb.prefix();
    let module = prefix.split('.').next_back().unwrap();

    let linear_config = LoraLinearConfig::new(d1, d2);
    let inner = layers::linear(d1, d2, base_vb.clone())?;

    let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
    for (_, cfg) in lora_config {
        if target_modules
            .as_ref()
            .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
        {
            candle_core::bail!("Expected all target modules to be the same.");
        }
    }

    if !target_modules
        .as_ref()
        .is_some_and(|target_modules| target_modules.contains(module))
    {
        return Ok(Arc::new(inner));
    }
    let name = prefix.split("lora_A").last().unwrap();
    let layer = if let Some(ref layers) = ord.layers {
        *layers.get(name).unwrap()
    } else {
        0
    };

    let lorainner = LoraLinear::new(
        &inner,
        &linear_config,
        lora_config,
        &vb,
        layer,
        preload_adapters,
    )?;
    *count += 1;
    Ok(Arc::new(lorainner))
}

#[allow(clippy::too_many_arguments)]
/// Build a linear layer without bias, wrapping it in a `LoraLinear` when the
/// current module is one of the configured LoRA target modules.
pub fn linear_no_bias(
    d1: usize,
    d2: usize,
    base_vb: ShardedVarBuilder,
    vb: ShardedVarBuilder,
    lora_config: &[((String, String), LoraConfig)],
    count: &mut usize,
    ord: &Ordering,
    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
    let prefix = vb.prefix();
    let module = prefix.split('.').next_back().unwrap();

    let linear_config = LoraLinearConfig::new(d1, d2);
    let inner = layers::linear_no_bias(d1, d2, base_vb.clone())?;

    let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
    for (_, cfg) in lora_config {
        if target_modules
            .as_ref()
            .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
        {
            candle_core::bail!("Expected all target modules to be the same.");
        }
    }

    if !target_modules
        .as_ref()
        .is_some_and(|target_modules| target_modules.contains(module))
    {
        return Ok(Arc::new(inner));
    }
    let name = prefix.split("lora_A").last().unwrap();
    let layer = if let Some(ref layers) = ord.layers {
        *layers.get(name).unwrap()
    } else {
        0
    };

    let lorainner = LoraLinear::new(
        &inner,
        &linear_config,
        lora_config,
        &vb,
        layer,
        preload_adapters,
    )?;
    *count += 1;
    Ok(Arc::new(lorainner))
}

/// Select the scalings for a single layer, going from
/// `(batch, seq, n_layers, n_adapters)` to `(batch, seq, n_adapters)`.
fn get_maybe_topk_scalings(scalings: Tensor, layer: usize) -> Result<Tensor> {
    scalings.i((.., .., layer, ..))
}

#[allow(clippy::too_many_arguments)]
/// Build a linear layer with or without bias, dispatching to `linear` or
/// `linear_no_bias` and applying the same LoRA wrapping in either case.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    base_vb: ShardedVarBuilder,
    vb: ShardedVarBuilder,
    lora_config: &[((String, String), LoraConfig)],
    count: &mut usize,
    ord: &Ordering,
    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
    if bias {
        linear(
            in_dim,
            out_dim,
            base_vb,
            vb,
            lora_config,
            count,
            ord,
            preload_adapters,
        )
    } else {
        linear_no_bias(
            in_dim,
            out_dim,
            base_vb,
            vb,
            lora_config,
            count,
            ord,
            preload_adapters,
        )
    }
}

/// Build a `LoraLinearConfig` from a quantized weight tensor, whose shape is
/// `(out_features, in_features)`.
pub fn get_lora_cfg(tensor: &QTensor) -> LoraLinearConfig {
    LoraLinearConfig::new(tensor.shape().dims()[1], tensor.shape().dims()[0])
}