mistralrs_quant/bitsandbytes/
mod.rs

1use std::{
2    borrow::Cow,
3    sync::{atomic::AtomicUsize, Arc},
4};
5
6use candle_core::{Context, DType, Device, Result, Shape, Tensor};
7use serde::Deserialize;
8
9use crate::{
10    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, ShardedVarBuilder,
11};
12
13#[cfg(feature = "cuda")]
14mod ffi;
15
16mod op;
17
/// Block sizes accepted by `BnbLinear::dequantize`; any other `blocksize` is rejected with an
/// error that lists these values in this order.
const SUPPORTED_BLOCKSIZE: [usize; 7] = [2048, 4096, 1024, 512, 256, 128, 64];
19
20#[derive(Debug, Deserialize, Clone, Copy)]
21pub enum BnbDType {
22    #[serde(rename = "float32")]
23    F32,
24    #[serde(rename = "bfloat16")]
25    BF16,
26    #[serde(rename = "float16")]
27    F16,
28}
29
/// Quantization format tag handed to the dequantization op.
#[derive(Clone, Copy, Debug)]
pub enum BnbQuantType {
    /// 8-bit codes. In this file it is only passed when dequantizing a nested `absmax`.
    Int8,
    /// Selected when the weight carries a `quant_state.bitsandbytes__fp4` tensor.
    Fp4,
    /// Selected when the weight carries a `quant_state.bitsandbytes__nf4` tensor.
    Nf4,
}
36
37impl From<BnbDType> for DType {
38    fn from(value: BnbDType) -> Self {
39        match value {
40            BnbDType::F32 => Self::F32,
41            BnbDType::BF16 => Self::BF16,
42            BnbDType::F16 => Self::F16,
43        }
44    }
45}
46
47#[derive(Debug, Deserialize)]
48pub struct BnbQuantState {
49    pub blocksize: usize,
50    pub shape: Vec<usize>,
51    pub dtype: BnbDType,
52    pub nested_blocksize: Option<usize>,
53    pub nested_offset: Option<f64>,
54    pub nested_dtype: Option<BnbDType>,
55}
56
57#[derive(Debug, Clone)]
58pub struct BnbQuantParmas {
59    pub absmax: Tensor,
60    pub code: Tensor,
61    pub blocksize: usize,
62    pub shape: Option<Shape>,
63    pub nested: Option<Arc<BnbQuantParmas>>,
64    pub offset: Option<f64>,
65    pub dtype: BnbDType,
66}
67
68#[derive(Debug)]
69pub struct BnbLinear {
70    weight: Tensor,
71    bias: Option<Tensor>,
72    params: BnbQuantParmas,
73    quant_ty: BnbQuantType,
74}
75
76impl BnbLinear {
77    pub fn linear_b(
78        _in_dim: usize,
79        out_dim: usize,
80        bias: bool,
81        vb: ShardedVarBuilder,
82    ) -> Result<Self> {
83        let weight = vb.get_unchecked_dtype("weight", DType::U8)?;
84
85        let vb_w = vb.pp("weight");
86
87        if !vb_w.contains_tensor("quant_state.bitsandbytes__nf4")
88            && !vb_w.contains_tensor("quant_state.bitsandbytes__fp4")
89        {
90            candle_core::bail!("`BnbLinear` expects either `...__nf4` or `...__fp4` tensors, this means the layer is not 4bit.");
91        }
92
93        let bias = if bias {
94            Some(vb.get((out_dim,), "bias")?)
95        } else {
96            None
97        };
98
99        let quant_ty = if vb_w.contains_tensor("quant_state.bitsandbytes__nf4") {
100            BnbQuantType::Nf4
101        } else if vb_w.contains_tensor("quant_state.bitsandbytes__fp4") {
102            BnbQuantType::Fp4
103        } else {
104            BnbQuantType::Int8
105        };
106
107        let state = match quant_ty {
108            BnbQuantType::Nf4 => {
109                Some(vb_w.get_unchecked_dtype("quant_state.bitsandbytes__nf4", DType::U8)?)
110            }
111            BnbQuantType::Fp4 => {
112                Some(vb_w.get_unchecked_dtype("quant_state.bitsandbytes__fp4", DType::U8)?)
113            }
114            BnbQuantType::Int8 => None,
115        };
116        let Some(state) = state else {
117            candle_core::bail!("Only fp8/nf4 quantization is supported for now.")
118        };
119
120        let state_str = String::from_utf8(state.to_vec1::<u8>()?)?;
121        let state: BnbQuantState =
122            serde_json::from_str(&state_str).map_err(candle_core::Error::msg)?;
123
124        let nested = if vb_w.contains_tensor("nested_absmax") {
125            // TODO: can `nested_blocksize` be None, default to 64 like bnb?
126            Some(Arc::new(BnbQuantParmas {
127                absmax: vb_w.get_unchecked_dtype("nested_absmax", DType::F32)?,
128                code: vb_w.get_unchecked_dtype("nested_quant_map", DType::F32)?,
129                blocksize: state
130                    .nested_blocksize
131                    .context("`nested_blocksize` must be present.")?,
132                shape: None,
133                nested: None,
134                offset: None, // Put it in the outer one!
135                dtype: state
136                    .nested_dtype
137                    .context("`nested_dtype` must be present.")?,
138            }))
139        } else {
140            None
141        };
142
143        let absmax = if nested.is_some() {
144            vb_w.get_unchecked_dtype("absmax", DType::U8)?
145        } else {
146            vb_w.get_unchecked_dtype("absmax", DType::F32)?
147        };
148
149        let params = BnbQuantParmas {
150            absmax,
151            code: vb_w.get_unchecked_dtype("quant_map", DType::F32)?,
152            blocksize: state.blocksize,
153            shape: Some(Shape::from_dims(&state.shape)),
154            nested,
155            offset: state.nested_offset,
156            dtype: state.dtype,
157        };
158
159        Ok(Self {
160            weight,
161            bias,
162            params,
163            quant_ty,
164        })
165    }
166
167    /// Dequantize input (u8). Handles nested absmax dequantization.
168    fn dequantize(
169        input: &Tensor,
170        params: &BnbQuantParmas,
171        quant_ty: BnbQuantType,
172    ) -> Result<Tensor> {
173        let mut absmax = params.absmax.clone();
174        if let Some(nested) = &params.nested {
175            absmax = Self::dequantize(&params.absmax, nested, BnbQuantType::Int8)?;
176            absmax = (absmax + params.offset.context("`offset` must be present.")?)?;
177        }
178
179        let out_shape = params.shape.clone().unwrap_or(input.shape().clone());
180        let out_dtype: DType = params.dtype.into();
181
182        if !SUPPORTED_BLOCKSIZE.contains(&params.blocksize) {
183            candle_core::bail!(
184                "Blocksize of {} is not supported, {SUPPORTED_BLOCKSIZE:?} are.",
185                params.blocksize
186            );
187        }
188
189        op::dequantize(
190            input,
191            &absmax,
192            &params.code,
193            out_shape,
194            params.blocksize,
195            quant_ty,
196            params.dtype,
197        )?
198        .to_dtype(out_dtype)
199    }
200}
201
202impl QuantMethod for BnbLinear {
203    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
204    where
205        Self: Sized,
206    {
207        match method {
208            QuantMethodConfig::Gguf { .. }
209            | QuantMethodConfig::Gptq { .. }
210            | QuantMethodConfig::Hqq { .. }
211            | QuantMethodConfig::Dummy
212            | QuantMethodConfig::Unquantized(_)
213            | QuantMethodConfig::FP8 { .. }
214            | QuantMethodConfig::BlockwiseFP8 { .. }
215            | QuantMethodConfig::Afq { .. } => unreachable!(),
216            QuantMethodConfig::Bnb {
217                weight,
218                bias,
219                params,
220                quant_ty,
221            } => Ok(Self {
222                weight,
223                bias,
224                params,
225                quant_ty,
226            }),
227        }
228    }
229
230    fn dequantize_w(&self) -> Result<Tensor> {
231        Self::dequantize(&self.weight, &self.params, self.quant_ty)
232    }
233
234    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
235        let w = Self::dequantize(&self.weight, &self.params, self.quant_ty)?
236            .t()?
237            .to_dtype(xs.dtype())?;
238        let res = xs.broadcast_matmul(&w)?;
239        if let Some(bias) = &self.bias {
240            res.broadcast_add(bias)
241        } else {
242            Ok(res)
243        }
244    }
245
246    fn quantized_act_type(&self) -> Option<DType> {
247        None
248    }
249
250    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
251        candle_core::bail!("HQQ quantization does not support adding weight delta.")
252    }
253
254    fn dtype_and_device(&self) -> (DType, Device) {
255        (self.params.dtype.into(), self.weight.device().clone())
256    }
257
258    fn apply_isq(
259        self: Arc<Self>,
260        _dtype: Option<IsqType>,
261        _device: Device,
262        _n_quantized: &AtomicUsize,
263        _imatrix_weight: Option<Vec<f32>>,
264        _guard: QuantizeOntoGuard,
265    ) -> Result<Arc<dyn QuantMethod>> {
266        todo!()
267    }
268}
269
270impl QuantizedSerde for BnbLinear {
271    fn isq_serde_supported(&self) -> bool {
272        true
273    }
274    fn name(&self) -> &'static str {
275        "bnb-linear"
276    }
277    fn serialize(&self) -> Result<Cow<[u8]>> {
278        todo!()
279    }
280
281    fn deserialize(
282        _data: Cow<[u8]>,
283        _device: &Device,
284        _comm: &Arc<crate::Comm>,
285        _guard: QuantizeOntoGuard,
286    ) -> Result<Arc<dyn QuantMethod>>
287    where
288        Self: Sized,
289    {
290        todo!()
291    }
292}