1#![allow(clippy::excessive_precision)]
2
3use std::fmt::Debug;
4
5#[cfg(feature = "cuda")]
6use candle_core::cuda::{
7    cudarc::driver::{sys::CUstream, CudaSlice, DeviceRepr, ValidAsZeroBits},
8    CudaDevice,
9};
10
11use candle_core::{
12    backend::BackendStorage, CpuStorage, CustomOp3, Result, Shape, Tensor, WithDType,
13};
14
15#[cfg(feature = "cuda")]
16use crate::bitsandbytes::ffi;
17
18use super::{BnbDType, BnbQuantType};
19
20struct DequantizeOp {
21    n: usize,
22    blocksize: usize,
23    shape: Shape,
24    quant_ty: BnbQuantType,
25    out_ty: BnbDType,
26}
27
/// Maps a 4-bit NF4 code to its normalized value in `[-1, 1]`.
///
/// Only the low nibble of `val` is used; higher bits are ignored. The sixteen levels are
/// listed in ascending code order (`0b0000` -> `-1.0`, ..., `0b1111` -> `1.0`).
fn d_dequantize_nf4(val: u8) -> f32 {
    const NF4_LEVELS: [f32; 16] = [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ];
    NF4_LEVELS[usize::from(val & 0x0F)]
}
93
94fn d_dequantize_fp4_tree(val: u8, absmax: f32) -> f32 {
95    let sign = if (val & 0b1000) == 0b1000 { -1.0 } else { 1.0 };
96
97    if (val & 0b0100) == 0b0100 {
98        if (val & 0b0010) == 0b0010 {
100            if (val & 0b0001) == 0b0001 {
102                0.25000000 * absmax * sign } else {
105                0.16666667 * absmax * sign }
107        } else if (val & 0b0001) == 0b0001 {
108            0.50000000 * absmax * sign } else {
111            0.33333333 * absmax * sign }
113    } else if (val & 0b0010) == 0b0010 {
114        if (val & 0b0001) == 0b0001 {
116            1.00000000 * absmax * sign } else {
119            0.66666667 * absmax * sign }
121    } else if (val & 0b0001) == 0b0001 {
122        5.208333333e-03 * absmax * sign } else {
125        0.00000000 * absmax * sign }
127}
128
129impl DequantizeOp {
130    fn dequantize_cpu<T: WithDType + Debug>(
131        &self,
132        input: &[u8],
133        absmax: &[f32],
134        code: &[f32],
135        quant_ty: BnbQuantType,
136    ) -> Vec<T> {
137        match quant_ty {
138            BnbQuantType::Int8 => {
139                let mut out = vec![T::zero(); self.n];
140                for block_idx in (0..self.n).step_by(self.blocksize) {
141                    let valid_items = if self.n - block_idx >= self.blocksize {
142                        self.blocksize
143                    } else {
144                        self.n - block_idx
145                    };
146                    let block_end = block_idx + valid_items;
147                    for i in block_idx..block_end {
148                        out[i] = T::from_f64(
149                            (code[input[i] as usize] * absmax[block_idx / self.blocksize]) as f64,
150                        );
151                    }
152                }
153                out
154            }
155            BnbQuantType::Fp4 => {
156                let mut out = vec![T::zero(); self.shape.elem_count()];
157                for block_idx in (0..self.n).step_by(self.blocksize) {
158                    let valid_items = if self.n > self.blocksize + block_idx {
159                        self.blocksize
160                    } else {
161                        self.n - block_idx
162                    };
163                    let block_end = block_idx + valid_items;
164
165                    let local_abs_max = absmax[block_idx / self.blocksize];
166
167                    for i in block_idx..block_end {
168                        out[i * 2] =
169                            T::from_f64(d_dequantize_fp4_tree(input[i] >> 4, local_abs_max) as f64);
170                        out[i * 2 + 1] = T::from_f64(d_dequantize_fp4_tree(
171                            input[i] & 0x0F,
172                            local_abs_max,
173                        ) as f64);
174                    }
175                }
176                out
177            }
178            BnbQuantType::Nf4 => {
179                let mut out = vec![T::zero(); self.shape.elem_count()];
180                for block_idx in (0..self.n).step_by(self.blocksize) {
181                    let valid_items = if self.n > self.blocksize + block_idx {
182                        self.blocksize
183                    } else {
184                        self.n - block_idx
185                    };
186                    let block_end = block_idx + valid_items;
187
188                    let local_abs_max = absmax[block_idx / (self.blocksize / 2)];
189
190                    for i in block_idx..block_end {
191                        out[i * 2] =
192                            T::from_f64((d_dequantize_nf4(input[i] >> 4) * local_abs_max) as f64);
193                        out[i * 2 + 1] =
194                            T::from_f64((d_dequantize_nf4(input[i] & 0x0F) * local_abs_max) as f64);
195                    }
196                }
197                out
198            }
199        }
200    }
201
202    #[cfg(feature = "cuda")]
203    fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
204        &self,
205        input: &CudaSlice<u8>,
206        code: &CudaSlice<f32>,
207        absmax: &CudaSlice<f32>,
208        dev: &CudaDevice,
209        kernel: unsafe extern "C" fn(*const f32, *const u8, *const f32, *mut T, i32, i32, CUstream),
210    ) -> Result<CudaSlice<T>> {
211        use crate::utils::slice_ptr;
212
213        let out = unsafe { dev.alloc::<T>(self.shape.elem_count())? };
214
215        let (code, _code_guard) = slice_ptr(code, 0);
216        let (input, _input_guard) = slice_ptr(input, 0);
217        let (absmax, _absmax_guard) = slice_ptr(absmax, 0);
218        let (out_ptr, out_guard) = slice_ptr(&out, 0);
219
220        unsafe {
221            kernel(
222                code as *const _,
223                input as *const _,
224                absmax as *const _,
225                out_ptr as *mut _,
226                self.blocksize as i32,
227                self.shape.elem_count() as i32,
228                dev.cuda_stream().cu_stream(),
229            )
230        };
231
232        drop(out_guard);
233
234        Ok(out)
235    }
236}
237
238impl CustomOp3 for DequantizeOp {
239    fn name(&self) -> &'static str {
240        "dequantize-bnb"
241    }
242
243    fn cpu_fwd(
244        &self,
245        input_s: &CpuStorage,
246        input_l: &candle_core::Layout,
247        absmax_s: &CpuStorage,
248        absmax_l: &candle_core::Layout,
249        code_s: &CpuStorage,
250        code_l: &candle_core::Layout,
251    ) -> candle_core::Result<(CpuStorage, candle_core::Shape)> {
252        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
253            candle_core::bail!("All inputs must be contiguous");
254        }
255        match (input_s, absmax_s, code_s, self.out_ty) {
256            (
257                CpuStorage::U8(input),
258                CpuStorage::F32(absmax),
259                CpuStorage::F32(code),
260                BnbDType::BF16,
261            ) => Ok((
262                CpuStorage::BF16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
263                self.shape.clone(),
264            )),
265            (
266                CpuStorage::U8(input),
267                CpuStorage::F32(absmax),
268                CpuStorage::F32(code),
269                BnbDType::F16,
270            ) => Ok((
271                CpuStorage::F16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
272                self.shape.clone(),
273            )),
274            (
275                CpuStorage::U8(input),
276                CpuStorage::F32(absmax),
277                CpuStorage::F32(code),
278                BnbDType::F32,
279            ) => Ok((
280                CpuStorage::F32(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
281                self.shape.clone(),
282            )),
283            (i, a, c, t) => candle_core::bail!(
284                "Unsupported dtypes for cpu dequant: {:?} input, {:?} absmax, {:?} code, {:?} out",
285                i.dtype(),
286                a.dtype(),
287                c.dtype(),
288                t
289            ),
290        }
291    }
292
293    #[cfg(feature = "cuda")]
294    fn cuda_fwd(
295        &self,
296        input_s: &candle_core::CudaStorage,
297        input_l: &candle_core::Layout,
298        absmax_s: &candle_core::CudaStorage,
299        absmax_l: &candle_core::Layout,
300        code_s: &candle_core::CudaStorage,
301        code_l: &candle_core::Layout,
302    ) -> Result<(candle_core::CudaStorage, Shape)> {
303        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
304            candle_core::bail!("All inputs must be contiguous");
305        }
306        let input_slice = input_s.as_cuda_slice::<u8>()?;
307        let absmax_slice = absmax_s.as_cuda_slice::<f32>()?;
308        let code_slice = code_s.as_cuda_slice::<f32>()?;
309        let dev = input_s.device().clone();
310        let out = match (self.out_ty, self.quant_ty) {
311            (BnbDType::F32, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
312                self.dispatch_cuda_kernel::<f32>(
313                    input_slice,
314                    code_slice,
315                    absmax_slice,
316                    &dev,
317                    ffi::dequantize_blockwise_f32_nf4,
318                )?,
319                dev,
320            ),
321            (BnbDType::F16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
322                self.dispatch_cuda_kernel::<half::f16>(
323                    input_slice,
324                    code_slice,
325                    absmax_slice,
326                    &dev,
327                    ffi::dequantize_blockwise_f16_nf4,
328                )?,
329                dev,
330            ),
331            (BnbDType::BF16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
332                self.dispatch_cuda_kernel::<half::bf16>(
333                    input_slice,
334                    code_slice,
335                    absmax_slice,
336                    &dev,
337                    ffi::dequantize_blockwise_bf16_nf4,
338                )?,
339                dev,
340            ),
341
342            (BnbDType::F32, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
343                self.dispatch_cuda_kernel::<f32>(
344                    input_slice,
345                    code_slice,
346                    absmax_slice,
347                    &dev,
348                    ffi::dequantize_blockwise_f32_fp4,
349                )?,
350                dev,
351            ),
352            (BnbDType::F16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
353                self.dispatch_cuda_kernel::<half::f16>(
354                    input_slice,
355                    code_slice,
356                    absmax_slice,
357                    &dev,
358                    ffi::dequantize_blockwise_f16_fp4,
359                )?,
360                dev,
361            ),
362            (BnbDType::BF16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
363                self.dispatch_cuda_kernel::<half::bf16>(
364                    input_slice,
365                    code_slice,
366                    absmax_slice,
367                    &dev,
368                    ffi::dequantize_blockwise_bf16_fp4,
369                )?,
370                dev,
371            ),
372
373            (BnbDType::F32, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
374                self.dispatch_cuda_kernel::<f32>(
375                    input_slice,
376                    code_slice,
377                    absmax_slice,
378                    &dev,
379                    ffi::dequantize_blockwise_f32_int8,
380                )?,
381                dev,
382            ),
383            (BnbDType::F16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
384                self.dispatch_cuda_kernel::<half::f16>(
385                    input_slice,
386                    code_slice,
387                    absmax_slice,
388                    &dev,
389                    ffi::dequantize_blockwise_f16_int8,
390                )?,
391                dev,
392            ),
393            (BnbDType::BF16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
394                self.dispatch_cuda_kernel::<half::bf16>(
395                    input_slice,
396                    code_slice,
397                    absmax_slice,
398                    &dev,
399                    ffi::dequantize_blockwise_bf16_int8,
400                )?,
401                dev,
402            ),
403        };
404
405        Ok((out, self.shape.clone()))
406    }
407
408    #[cfg(feature = "metal")]
409    fn metal_fwd(
410        &self,
411        input_s: &candle_core::MetalStorage,
412        input_l: &candle_core::Layout,
413        absmax_s: &candle_core::MetalStorage,
414        absmax_l: &candle_core::Layout,
415        code_s: &candle_core::MetalStorage,
416        code_l: &candle_core::Layout,
417    ) -> Result<(candle_core::MetalStorage, Shape)> {
418        use candle_core::DType;
419
420        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
421            candle_core::bail!("All inputs must be contiguous");
422        }
423
424        let command_buffer = input_s.device().command_buffer()?;
425        command_buffer.set_label("dequant-bnb-nf4");
426
427        let device = input_s.device();
428
429        let output = device.new_buffer(
430            self.shape.elem_count(),
431            self.out_ty.into(),
432            "dequant-bnb-nf4",
433        )?;
434
435        if input_s.dtype() != DType::U8 {
436            candle_core::bail!("input must be u8");
437        }
438        if code_s.dtype() != DType::F32 {
439            candle_core::bail!("code must be f32");
440        }
441        if absmax_s.dtype() != DType::F32 {
442            candle_core::bail!("absmax must be f32");
443        }
444
445        match self.quant_ty {
446            BnbQuantType::Nf4 => crate::metal_kernels::call_dequant_bnb_nf4(
447                device.device(),
448                &command_buffer,
449                &crate::metal_kernels::Kernels::new(),
450                self.out_ty.into(),
451                input_s.buffer(),
452                absmax_s.buffer(),
453                code_s.buffer(),
454                &output,
455                self.blocksize,
456                self.n,
457            )
458            .map_err(candle_core::Error::wrap)?,
459            BnbQuantType::Fp4 => crate::metal_kernels::call_dequant_bnb_fp4(
460                device.device(),
461                &command_buffer,
462                &crate::metal_kernels::Kernels::new(),
463                self.out_ty.into(),
464                input_s.buffer(),
465                absmax_s.buffer(),
466                code_s.buffer(),
467                &output,
468                self.blocksize,
469                self.n,
470            )
471            .map_err(candle_core::Error::wrap)?,
472            BnbQuantType::Int8 => crate::metal_kernels::call_dequant_bnb_int8(
473                device.device(),
474                &command_buffer,
475                &crate::metal_kernels::Kernels::new(),
476                self.out_ty.into(),
477                input_s.buffer(),
478                absmax_s.buffer(),
479                code_s.buffer(),
480                &output,
481                self.blocksize,
482                self.n,
483            )
484            .map_err(candle_core::Error::wrap)?,
485        };
486
487        let newstorage = candle_core::MetalStorage::new(
488            output,
489            device.clone(),
490            self.shape.elem_count(),
491            self.out_ty.into(),
492        );
493        Ok((newstorage, self.shape.clone()))
494    }
495}
496
497pub fn dequantize(
498    input: &Tensor,
499    absmax: &Tensor,
500    code: &Tensor,
501    shape: Shape,
502    blocksize: usize,
503    quant_ty: BnbQuantType,
504    out_ty: BnbDType,
505) -> Result<Tensor> {
506    input.apply_op3(
507        absmax,
508        code,
509        DequantizeOp {
510            n: input.elem_count(),
511            blocksize,
512            shape,
513            quant_ty,
514            out_ty,
515        },
516    )
517}