// mistralrs_quant/bitsandbytes/op.rs

1#![allow(clippy::excessive_precision)]
2
3use std::fmt::Debug;
4
5#[cfg(feature = "cuda")]
6use candle_core::cuda::{
7    cudarc::driver::{sys::CUstream, CudaSlice, DeviceRepr, ValidAsZeroBits},
8    CudaDevice,
9};
10
11use candle_core::{
12    backend::BackendStorage, CpuStorage, CustomOp3, Result, Shape, Tensor, WithDType,
13};
14
15#[cfg(feature = "cuda")]
16use crate::bitsandbytes::ffi;
17
18use super::{BnbDType, BnbQuantType};
19
/// Parameters for one blockwise bitsandbytes dequantization, executed as a
/// `CustomOp3` over the (quantized input, absmax, code) tensors.
struct DequantizeOp {
    // Element count of the quantized input tensor. For the 4-bit formats
    // (Fp4/Nf4) this is a byte count; each byte packs two 4-bit values.
    n: usize,
    // Quantization block size: the run of values sharing one absmax scale.
    blocksize: usize,
    // Shape of the dequantized output tensor.
    shape: Shape,
    // Quantization format: Int8, Fp4, or Nf4.
    quant_ty: BnbQuantType,
    // Output element dtype (F32 / F16 / BF16).
    out_ty: BnbDType,
}
27
/// Map a 4-bit NF4 code (the low nibble of `val`; higher bits are ignored)
/// to its dequantized value in `[-1.0, 1.0]`.
///
/// The constants are the bitsandbytes normal-float codebook; they were
/// generated by test_normal_map_tree in the file tests/test_functional.py
/// of the bitsandbytes project.
fn d_dequantize_nf4(val: u8) -> f32 {
    // Direct codebook lookup indexed by the 4-bit code. This is equivalent
    // to the binary decision tree used by the reference CUDA kernel: the
    // tree simply resolves the code bit-by-bit to one of these 16 values.
    const NF4_CODEBOOK: [f32; 16] = [
        -1.0,                  // 0b0000
        -0.6961928009986877,   // 0b0001
        -0.5250730514526367,   // 0b0010
        -0.39491748809814453,  // 0b0011
        -0.28444138169288635,  // 0b0100
        -0.18477343022823334,  // 0b0101
        -0.09105003625154495,  // 0b0110
        0.0,                   // 0b0111
        0.07958029955625534,   // 0b1000
        0.16093020141124725,   // 0b1001
        0.24611230194568634,   // 0b1010
        0.33791524171829224,   // 0b1011
        0.44070982933044434,   // 0b1100
        0.5626170039176941,    // 0b1101
        0.7229568362236023,    // 0b1110
        1.0,                   // 0b1111
    ];
    NF4_CODEBOOK[(val & 0x0F) as usize]
}
93
/// Dequantize a 4-bit FP4 code: bit 3 of `val` is the sign bit, the low
/// three bits select a magnitude from the FP4 codebook, and the result is
/// scaled by the block's `absmax`.
fn d_dequantize_fp4_tree(val: u8, absmax: f32) -> f32 {
    // Magnitudes indexed by the low three bits of the code; equivalent to
    // the branch tree in the reference implementation.
    const FP4_MAGNITUDES: [f32; 8] = [
        0.00000000,      // 0b000
        5.208333333e-03, // 0b001
        0.66666667,      // 0b010
        1.00000000,      // 0b011
        0.33333333,      // 0b100
        0.50000000,      // 0b101
        0.16666667,      // 0b110
        0.25000000,      // 0b111
    ];
    let sign = if val & 0b1000 == 0 { 1.0 } else { -1.0 };
    // Same evaluation order as the original (magnitude * absmax * sign) so
    // the f32 rounding is bit-identical.
    FP4_MAGNITUDES[(val & 0b0111) as usize] * absmax * sign
}
128
impl DequantizeOp {
    /// CPU reference dequantization.
    ///
    /// `input` holds the quantized data (one value per byte for Int8; two
    /// packed 4-bit values per byte for Fp4/Nf4, high nibble first),
    /// `absmax` the per-block scales, and `code` the Int8 codebook (the
    /// 4-bit paths do not read `code`).
    fn dequantize_cpu<T: WithDType + Debug>(
        &self,
        input: &[u8],
        absmax: &[f32],
        code: &[f32],
        quant_ty: BnbQuantType,
    ) -> Vec<T> {
        match quant_ty {
            BnbQuantType::Int8 => {
                // One output element per input byte: `code` maps the byte to
                // a float which is then scaled by the block's absmax.
                let mut out = vec![T::zero(); self.n];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // The final block may be shorter than `blocksize`.
                    let valid_items = if self.n - block_idx >= self.blocksize {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    for i in block_idx..block_end {
                        out[i] = T::from_f64(
                            (code[input[i] as usize] * absmax[block_idx / self.blocksize]) as f64,
                        );
                    }
                }
                out
            }
            BnbQuantType::Fp4 => {
                // Each input byte packs two 4-bit codes, so the output holds
                // two elements per byte; sized from the declared output shape.
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // Handle a short trailing block.
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;

                    let local_abs_max = absmax[block_idx / self.blocksize];

                    for i in block_idx..block_end {
                        // High nibble first, then low nibble.
                        out[i * 2] =
                            T::from_f64(d_dequantize_fp4_tree(input[i] >> 4, local_abs_max) as f64);
                        out[i * 2 + 1] = T::from_f64(d_dequantize_fp4_tree(
                            input[i] & 0x0F,
                            local_abs_max,
                        ) as f64);
                    }
                }
                out
            }
            BnbQuantType::Nf4 => {
                // Same two-codes-per-byte packing as Fp4.
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;

                    // NOTE(review): this indexes absmax by `blocksize / 2`,
                    // whereas the Fp4 branch above divides by `blocksize` —
                    // presumably `blocksize` is expressed in different units
                    // (elements vs. bytes) for the two formats; confirm
                    // against the callers that construct this op.
                    let local_abs_max = absmax[block_idx / (self.blocksize / 2)];

                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64((d_dequantize_nf4(input[i] >> 4) * local_abs_max) as f64);
                        out[i * 2 + 1] =
                            T::from_f64((d_dequantize_nf4(input[i] & 0x0F) * local_abs_max) as f64);
                    }
                }
                out
            }
        }
    }

    /// Launch one of the CUDA blockwise-dequantize FFI `kernel`s over the
    /// device slices, allocating the output buffer on `dev`.
    ///
    /// The kernel signature is (code, input, absmax, out, blocksize, n, stream).
    #[cfg(feature = "cuda")]
    fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
        &self,
        input: &CudaSlice<u8>,
        code: &CudaSlice<f32>,
        absmax: &CudaSlice<f32>,
        dev: &CudaDevice,
        kernel: unsafe extern "C" fn(*const f32, *const u8, *const f32, *mut T, i32, i32, CUstream),
    ) -> Result<CudaSlice<T>> {
        use candle_core::cuda::{cudarc::driver::DevicePtr, WrapErr};

        // SAFETY: uninitialized allocation; the kernel is expected to write
        // all `elem_count` output elements — TODO confirm for short inputs.
        let out = unsafe { dev.alloc::<T>(self.shape.elem_count()).w()? };
        // SAFETY: all pointers come from live CudaSlices on `dev`, and the
        // launch is enqueued on that device's stream.
        unsafe {
            kernel(
                (*code.device_ptr()) as *const _,
                (*input.device_ptr()) as *const _,
                (*absmax.device_ptr()) as *const _,
                (*out.device_ptr()) as *mut _,
                self.blocksize as i32,
                self.shape.elem_count() as i32,
                *dev.cu_stream(),
            )
        };

        Ok(out)
    }
}
229
impl CustomOp3 for DequantizeOp {
    fn name(&self) -> &'static str {
        "dequantize-bnb"
    }

    /// CPU path: validates contiguity, then dispatches `dequantize_cpu` on
    /// the (input, absmax, code, out_ty) dtype combination. Input must be
    /// U8 and absmax/code F32; only the output dtype varies.
    fn cpu_fwd(
        &self,
        input_s: &CpuStorage,
        input_l: &candle_core::Layout,
        absmax_s: &CpuStorage,
        absmax_l: &candle_core::Layout,
        code_s: &CpuStorage,
        code_l: &candle_core::Layout,
    ) -> candle_core::Result<(CpuStorage, candle_core::Shape)> {
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        match (input_s, absmax_s, code_s, self.out_ty) {
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::BF16,
            ) => Ok((
                CpuStorage::BF16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F16,
            ) => Ok((
                CpuStorage::F16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F32,
            ) => Ok((
                CpuStorage::F32(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (i, a, c, t) => candle_core::bail!(
                "Unsupported dtypes for cpu dequant: {:?} input, {:?} absmax, {:?} code, {:?} out",
                i.dtype(),
                a.dtype(),
                c.dtype(),
                t
            ),
        }
    }

    /// CUDA path: selects the matching FFI kernel for the
    /// (out_ty, quant_ty) pair and launches it via `dispatch_cuda_kernel`.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        input_s: &candle_core::CudaStorage,
        input_l: &candle_core::Layout,
        absmax_s: &candle_core::CudaStorage,
        absmax_l: &candle_core::Layout,
        code_s: &candle_core::CudaStorage,
        code_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        // These casts enforce the same dtype contract as the CPU path
        // (u8 input, f32 absmax/code); they error out otherwise.
        let input_slice = input_s.as_cuda_slice::<u8>()?;
        let absmax_slice = absmax_s.as_cuda_slice::<f32>()?;
        let code_slice = code_s.as_cuda_slice::<f32>()?;
        let dev = input_s.device().clone();
        // 3 output dtypes x 3 quant formats = 9 kernel variants.
        let out = match (self.out_ty, self.quant_ty) {
            (BnbDType::F32, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_nf4,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_nf4,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_nf4,
                )?,
                dev,
            ),

            (BnbDType::F32, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_fp4,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_fp4,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_fp4,
                )?,
                dev,
            ),

            (BnbDType::F32, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_int8,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_int8,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_int8,
                )?,
                dev,
            ),
        };

        Ok((out, self.shape.clone()))
    }

    /// Metal path: allocates an output buffer and enqueues the matching
    /// metal_kernels dequant call on the device command buffer.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        input_s: &candle_core::MetalStorage,
        input_l: &candle_core::Layout,
        absmax_s: &candle_core::MetalStorage,
        absmax_l: &candle_core::Layout,
        code_s: &candle_core::MetalStorage,
        code_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        use candle_core::DType;

        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }

        let command_buffer = input_s.device().command_buffer()?;
        // NOTE(review): the "-nf4" label is also used for the Fp4/Int8
        // paths below — debug-only labels, but slightly misleading.
        command_buffer.set_label("dequant-bnb-nf4");

        let device = input_s.device();

        let output = device.new_buffer(
            self.shape.elem_count(),
            self.out_ty.into(),
            "dequant-bnb-nf4",
        )?;

        // Same dtype contract as the CPU/CUDA paths.
        if input_s.dtype() != DType::U8 {
            candle_core::bail!("input must be u8");
        }
        if code_s.dtype() != DType::F32 {
            candle_core::bail!("code must be f32");
        }
        if absmax_s.dtype() != DType::F32 {
            candle_core::bail!("absmax must be f32");
        }

        match self.quant_ty {
            BnbQuantType::Nf4 => crate::metal_kernels::call_dequant_bnb_nf4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
            BnbQuantType::Fp4 => crate::metal_kernels::call_dequant_bnb_fp4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
            BnbQuantType::Int8 => crate::metal_kernels::call_dequant_bnb_int8(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
        };

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            self.shape.elem_count(),
            self.out_ty.into(),
        );
        Ok((newstorage, self.shape.clone()))
    }
}
488
489pub fn dequantize(
490    input: &Tensor,
491    absmax: &Tensor,
492    code: &Tensor,
493    shape: Shape,
494    blocksize: usize,
495    quant_ty: BnbQuantType,
496    out_ty: BnbDType,
497) -> Result<Tensor> {
498    input.apply_op3(
499        absmax,
500        code,
501        DequantizeOp {
502            n: input.elem_count(),
503            blocksize,
504            shape,
505            quant_ty,
506            out_ty,
507        },
508    )
509}