mistralrs_quant/bitsandbytes/mod.rs

use std::{
    borrow::Cow,
    sync::{atomic::AtomicUsize, Arc},
};

use candle_core::{Context, DType, Device, Result, Shape, Tensor};
use serde::Deserialize;

use crate::{
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, ShardedVarBuilder,
};

#[cfg(feature = "cuda")]
mod ffi;

mod op;

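/// Block sizes accepted by the dequantization kernels.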
const SUPPORTED_BLOCKSIZE: [usize; 7] = [2048, 4096, 1024, 512, 256, 128, 64];

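/// Floating-point dtype recorded in a bitsandbytes quantization state.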
#[derive(Debug, Deserialize, Clone, Copy)]
pub enum BnbDType {
    #[serde(rename = "float32")]
    F32,
    #[serde(rename = "bfloat16")]
    BF16,
    #[serde(rename = "float16")]
    F16,
}

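/// Quantization scheme a bitsandbytes tensor was stored with.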
#[derive(Debug, Clone, Copy)]
pub enum BnbQuantType {
    Int8,
    Fp4,
    Nf4,
}

impl From<BnbDType> for DType {
    fn from(value: BnbDType) -> Self {
        match value {
            BnbDType::F32 => Self::F32,
            BnbDType::BF16 => Self::BF16,
            BnbDType::F16 => Self::F16,
        }
    }
}

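/// Quantization state parsed from the JSON blob stored alongside a 4-bit
/// weight (e.g. the `weight.quant_state.bitsandbytes__nf4` tensor).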
#[derive(Debug, Deserialize)]
pub struct BnbQuantState {
    pub blocksize: usize,
    pub shape: Vec<usize>,
    pub dtype: BnbDType,
    pub nested_blocksize: Option<usize>,
    pub nested_offset: Option<f64>,
    pub nested_dtype: Option<BnbDType>,
}

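/// Dequantization parameters for one tensor. `nested` is populated when the
/// absmax values are themselves 8-bit quantized (double quantization).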
#[derive(Debug, Clone)]
pub struct BnbQuantParmas {
    pub absmax: Tensor,
    pub code: Tensor,
    pub blocksize: usize,
    pub shape: Option<Shape>,
    pub nested: Option<Arc<BnbQuantParmas>>,
    pub offset: Option<f64>,
    pub dtype: BnbDType,
}

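/// Linear layer whose weight is stored in a 4-bit bitsandbytes format.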
#[derive(Debug)]
pub struct BnbLinear {
    weight: Tensor,
    bias: Option<Tensor>,
    params: BnbQuantParmas,
    quant_ty: BnbQuantType,
}

impl BnbLinear {
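    /// Load a 4-bit (NF4/FP4) linear layer from `vb`: the packed `weight`,
    /// its serialized quant state, the optional bias, and any nested
    /// (double-quantized) absmax tensors.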
    pub fn linear_b(
        _in_dim: usize,
        out_dim: usize,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let weight = vb.get_unchecked_dtype("weight", DType::U8)?;

        let vb_w = vb.pp("weight");

        if !vb_w.contains_tensor("quant_state.bitsandbytes__nf4")
            && !vb_w.contains_tensor("quant_state.bitsandbytes__fp4")
        {
            candle_core::bail!("`BnbLinear` expects either `...__nf4` or `...__fp4` tensors; this layer does not appear to be 4-bit.");
        }

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };
        let quant_ty = if vb_w.contains_tensor("quant_state.bitsandbytes__nf4") {
            BnbQuantType::Nf4
        } else if vb_w.contains_tensor("quant_state.bitsandbytes__fp4") {
            BnbQuantType::Fp4
        } else {
            BnbQuantType::Int8
        };

        let state = match quant_ty {
            BnbQuantType::Nf4 => {
                Some(vb_w.get_unchecked_dtype("quant_state.bitsandbytes__nf4", DType::U8)?)
            }
            BnbQuantType::Fp4 => {
                Some(vb_w.get_unchecked_dtype("quant_state.bitsandbytes__fp4", DType::U8)?)
            }
            BnbQuantType::Int8 => None,
        };
        let Some(state) = state else {
            candle_core::bail!("Only fp4/nf4 quantization is supported for now.")
        };

        let state_str = String::from_utf8(state.to_vec1::<u8>()?)?;
        let state: BnbQuantState =
            serde_json::from_str(&state_str).map_err(candle_core::Error::msg)?;

        let nested = if vb_w.contains_tensor("nested_absmax") {
            Some(Arc::new(BnbQuantParmas {
                absmax: vb_w.get_unchecked_dtype("nested_absmax", DType::F32)?,
                code: vb_w.get_unchecked_dtype("nested_quant_map", DType::F32)?,
                blocksize: state
                    .nested_blocksize
                    .context("`nested_blocksize` must be present.")?,
                shape: None,
                nested: None,
                offset: None,
                dtype: state
                    .nested_dtype
                    .context("`nested_dtype` must be present.")?,
            }))
        } else {
            None
        };

        let absmax = if nested.is_some() {
            vb_w.get_unchecked_dtype("absmax", DType::U8)?
        } else {
            vb_w.get_unchecked_dtype("absmax", DType::F32)?
        };

        let params = BnbQuantParmas {
            absmax,
            code: vb_w.get_unchecked_dtype("quant_map", DType::F32)?,
            blocksize: state.blocksize,
            shape: Some(Shape::from_dims(&state.shape)),
            nested,
            offset: state.nested_offset,
            dtype: state.dtype,
        };

        Ok(Self {
            weight,
            bias,
            params,
            quant_ty,
        })
    }

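    /// Dequantize `input` with `params`. If the absmax values are nested
    /// (double quantization), they are first dequantized as `Int8` and
    /// shifted by `offset` before the main kernel runs.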
    fn dequantize(
        input: &Tensor,
        params: &BnbQuantParmas,
        quant_ty: BnbQuantType,
    ) -> Result<Tensor> {
        let mut absmax = params.absmax.clone();
        if let Some(nested) = &params.nested {
            absmax = Self::dequantize(&params.absmax, nested, BnbQuantType::Int8)?;
            absmax = (absmax + params.offset.context("`offset` must be present.")?)?;
        }

        let out_shape = params.shape.clone().unwrap_or(input.shape().clone());
        let out_dtype: DType = params.dtype.into();

        if !SUPPORTED_BLOCKSIZE.contains(&params.blocksize) {
            candle_core::bail!(
                "Blocksize {} is not supported; supported blocksizes are {SUPPORTED_BLOCKSIZE:?}.",
                params.blocksize
            );
        }

        op::dequantize(
            input,
            &absmax,
            &params.code,
            out_shape,
            params.blocksize,
            quant_ty,
            params.dtype,
        )?
        .to_dtype(out_dtype)
    }
}

impl QuantMethod for BnbLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Bnb {
                weight,
                bias,
                params,
                quant_ty,
            } => Ok(Self {
                weight,
                bias,
                params,
                quant_ty,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Self::dequantize(&self.weight, &self.params, self.quant_ty)
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
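        // Dequantize the packed weight to the activation dtype, transpose it,
        // and apply a standard (optionally biased) matmul.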
        let w = Self::dequantize(&self.weight, &self.params, self.quant_ty)?
            .t()?
            .to_dtype(xs.dtype())?;
        let res = xs.broadcast_matmul(&w)?;
        if let Some(bias) = &self.bias {
            res.broadcast_add(bias)
        } else {
            Ok(res)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BnB quantization does not support adding a weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.params.dtype.into(), self.weight.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for BnbLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "bnb-linear"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        todo!()
    }

    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        todo!()
    }
}
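
// A minimal usage sketch (hypothetical call site; `vb` is assumed to be a
// `ShardedVarBuilder` rooted at a layer whose tensors were saved by
// bitsandbytes in NF4/FP4 format):
//
//     let layer = BnbLinear::linear_b(in_dim, out_dim, /* bias */ true, vb)?;
//     let ys = layer.forward(&xs)?;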