#![allow(clippy::excessive_precision)]
use std::fmt::Debug;
#[cfg(feature = "cuda")]
use candle_core::cuda::{
cudarc::driver::{sys::CUstream, CudaSlice, DeviceRepr, ValidAsZeroBits},
CudaDevice,
};
use candle_core::{
backend::BackendStorage, CpuStorage, CustomOp3, Result, Shape, Tensor, WithDType,
};
#[cfg(feature = "cuda")]
use crate::bitsandbytes::ffi;
use super::{BnbDType, BnbQuantType};
/// Custom candle op that dequantizes bitsandbytes blockwise-quantized data
/// (Int8 / FP4 / NF4) into a floating-point tensor.
struct DequantizeOp {
    // Element count of the packed input storage (bytes; for 4-bit types each
    // byte holds two quantized values).
    n: usize,
    // Quantization block size; there is one `absmax` scale per block.
    blocksize: usize,
    // Shape of the dequantized output tensor.
    shape: Shape,
    // Input quantization format (Int8, Fp4 or Nf4).
    quant_ty: BnbQuantType,
    // Output dtype (F32, F16 or BF16).
    out_ty: BnbDType,
}
/// Map a 4-bit NF4 code to its floating-point value.
///
/// Only the low nibble of `val` is inspected. The 16 entries are the NF4
/// (normal-float 4-bit) quantization levels, spanning [-1.0, 1.0].
fn d_dequantize_nf4(val: u8) -> f32 {
    // NF4 codebook, indexed directly by the 4-bit code value.
    const NF4_CODEBOOK: [f32; 16] = [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ];
    NF4_CODEBOOK[(val & 0x0F) as usize]
}
/// Dequantize a 4-bit FP4 code: bit 3 carries the sign, the low three bits
/// select a magnitude from the FP4 codebook, and the result is scaled by the
/// block's `absmax`.
fn d_dequantize_fp4_tree(val: u8, absmax: f32) -> f32 {
    // FP4 magnitude codebook, indexed by the low three bits of the code.
    const FP4_MAGNITUDES: [f32; 8] = [
        0.00000000,
        5.208333333e-03,
        0.66666667,
        1.00000000,
        0.33333333,
        0.50000000,
        0.16666667,
        0.25000000,
    ];
    let sign = if val & 0b1000 != 0 { -1.0 } else { 1.0 };
    // Multiply in the same order as the reference (magnitude * absmax * sign)
    // so results are bit-identical.
    FP4_MAGNITUDES[(val & 0b0111) as usize] * absmax * sign
}
impl DequantizeOp {
    /// Scalar CPU reference dequantization into a freshly allocated `Vec<T>`.
    ///
    /// `input` holds the quantized bytes (one value per byte for Int8, two
    /// packed 4-bit values per byte for Fp4/Nf4), `absmax` the per-block
    /// scales, and `code` the 256-entry Int8 codebook (only read for Int8).
    fn dequantize_cpu<T: WithDType + Debug>(
        &self,
        input: &[u8],
        absmax: &[f32],
        code: &[f32],
        quant_ty: BnbQuantType,
    ) -> Vec<T> {
        match quant_ty {
            BnbQuantType::Int8 => {
                // One output element per input byte.
                let mut out = vec![T::zero(); self.n];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // The last block may be shorter than `blocksize`.
                    let valid_items = if self.n - block_idx >= self.blocksize {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    for i in block_idx..block_end {
                        // Look the byte up in the codebook, then scale by this
                        // block's absmax.
                        out[i] = T::from_f64(
                            (code[input[i] as usize] * absmax[block_idx / self.blocksize]) as f64,
                        );
                    }
                }
                out
            }
            BnbQuantType::Fp4 => {
                // Each input byte expands to two output elements (high nibble
                // first), so the output is sized from the target shape.
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // The last block may be shorter than `blocksize`.
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    let local_abs_max = absmax[block_idx / self.blocksize];
                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64(d_dequantize_fp4_tree(input[i] >> 4, local_abs_max) as f64);
                        out[i * 2 + 1] = T::from_f64(d_dequantize_fp4_tree(
                            input[i] & 0x0F,
                            local_abs_max,
                        ) as f64);
                    }
                }
                out
            }
            BnbQuantType::Nf4 => {
                // Same two-elements-per-byte layout as Fp4.
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // The last block may be shorter than `blocksize`.
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    // NOTE(review): unlike Fp4 above, the absmax index uses
                    // `blocksize / 2` — presumably `blocksize` counts output
                    // elements for NF4 while block iteration is over packed
                    // bytes; confirm against the quantizer that produced the
                    // absmax layout.
                    let local_abs_max = absmax[block_idx / (self.blocksize / 2)];
                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64((d_dequantize_nf4(input[i] >> 4) * local_abs_max) as f64);
                        out[i * 2 + 1] =
                            T::from_f64((d_dequantize_nf4(input[i] & 0x0F) * local_abs_max) as f64);
                    }
                }
                out
            }
        }
    }

    /// Launch one of the bitsandbytes FFI dequantize kernels and return the
    /// freshly allocated output slice.
    ///
    /// The kernel signature is `(code, input, absmax, out, blocksize, n,
    /// stream)`; `n` here is the output element count.
    #[cfg(feature = "cuda")]
    fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
        &self,
        input: &CudaSlice<u8>,
        code: &CudaSlice<f32>,
        absmax: &CudaSlice<f32>,
        dev: &CudaDevice,
        kernel: unsafe extern "C" fn(*const f32, *const u8, *const f32, *mut T, i32, i32, CUstream),
    ) -> Result<CudaSlice<T>> {
        use candle_core::cuda::{cudarc::driver::DevicePtr, WrapErr};
        // SAFETY: `alloc` returns uninitialized device memory; the kernel
        // below is expected to write all `elem_count` elements before the
        // slice is read. TODO(review): confirm the FFI kernels cover the full
        // range, including any tail block.
        let out = unsafe { dev.alloc::<T>(self.shape.elem_count()).w()? };
        // SAFETY: all pointers come from live `CudaSlice`s owned by the
        // caller/`out`, and are passed to an extern kernel on this device's
        // stream.
        unsafe {
            kernel(
                (*code.device_ptr()) as *const _,
                (*input.device_ptr()) as *const _,
                (*absmax.device_ptr()) as *const _,
                (*out.device_ptr()) as *mut _,
                self.blocksize as i32,
                self.shape.elem_count() as i32,
                *dev.cu_stream(),
            )
        };
        Ok(out)
    }
}
impl CustomOp3 for DequantizeOp {
    fn name(&self) -> &'static str {
        "dequantize-bnb"
    }

    /// CPU forward pass: validate layouts, dispatch on the storage dtypes and
    /// requested output dtype, and run the scalar reference implementation.
    fn cpu_fwd(
        &self,
        input_s: &CpuStorage,
        input_l: &candle_core::Layout,
        absmax_s: &CpuStorage,
        absmax_l: &candle_core::Layout,
        code_s: &CpuStorage,
        code_l: &candle_core::Layout,
    ) -> candle_core::Result<(CpuStorage, candle_core::Shape)> {
        // The dequantizers index storage linearly, so strided/offset inputs
        // are rejected up front.
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        // Quantized data must be u8; absmax scales and the codebook must be
        // f32. Only the output dtype varies.
        match (input_s, absmax_s, code_s, self.out_ty) {
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::BF16,
            ) => Ok((
                CpuStorage::BF16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F16,
            ) => Ok((
                CpuStorage::F16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F32,
            ) => Ok((
                CpuStorage::F32(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (i, a, c, t) => candle_core::bail!(
                "Unsupported dtypes for cpu dequant: {:?} input, {:?} absmax, {:?} code, {:?} out",
                i.dtype(),
                a.dtype(),
                c.dtype(),
                t
            ),
        }
    }

    /// CUDA forward pass: pick the FFI kernel matching (output dtype, quant
    /// type) and launch it via `dispatch_cuda_kernel`.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        input_s: &candle_core::CudaStorage,
        input_l: &candle_core::Layout,
        absmax_s: &candle_core::CudaStorage,
        absmax_l: &candle_core::Layout,
        code_s: &candle_core::CudaStorage,
        code_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        // Kernels assume densely packed buffers.
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        // Dtype checks happen here: `as_cuda_slice` fails if the storage does
        // not hold the requested element type.
        let input_slice = input_s.as_cuda_slice::<u8>()?;
        let absmax_slice = absmax_s.as_cuda_slice::<f32>()?;
        let code_slice = code_s.as_cuda_slice::<f32>()?;
        let dev = input_s.device().clone();
        // One arm per (out dtype, quant type) pair; all nine combinations are
        // covered, so no fallback arm is needed.
        let out = match (self.out_ty, self.quant_ty) {
            (BnbDType::F32, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_nf4,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_nf4,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_nf4,
                )?,
                dev,
            ),
            (BnbDType::F32, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_fp4,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_fp4,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_fp4,
                )?,
                dev,
            ),
            (BnbDType::F32, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_int8,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_int8,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_int8,
                )?,
                dev,
            ),
        };
        Ok((out, self.shape.clone()))
    }

    /// Metal forward pass: validate dtypes up front, then call the matching
    /// custom Metal kernel to fill a freshly allocated output buffer.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        input_s: &candle_core::MetalStorage,
        input_l: &candle_core::Layout,
        absmax_s: &candle_core::MetalStorage,
        absmax_l: &candle_core::Layout,
        code_s: &candle_core::MetalStorage,
        code_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        use candle_core::DType;
        // Kernels assume densely packed buffers.
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        let command_buffer = input_s.device().command_buffer()?;
        // NOTE(review): the label says "nf4" but this path also handles Fp4
        // and Int8 — cosmetic only; it affects debugging labels, not behavior.
        command_buffer.set_label("dequant-bnb-nf4");
        let device = input_s.device();
        let output = device.new_buffer(
            self.shape.elem_count(),
            self.out_ty.into(),
            "dequant-bnb-nf4",
        )?;
        // Unlike the CUDA path, dtypes are checked explicitly here.
        if input_s.dtype() != DType::U8 {
            candle_core::bail!("input must be u8");
        }
        if code_s.dtype() != DType::F32 {
            candle_core::bail!("code must be f32");
        }
        if absmax_s.dtype() != DType::F32 {
            candle_core::bail!("absmax must be f32");
        }
        match self.quant_ty {
            BnbQuantType::Nf4 => crate::metal_kernels::call_dequant_bnb_nf4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
            BnbQuantType::Fp4 => crate::metal_kernels::call_dequant_bnb_fp4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
            BnbQuantType::Int8 => crate::metal_kernels::call_dequant_bnb_int8(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
        };
        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            self.shape.elem_count(),
            self.out_ty.into(),
        );
        Ok((newstorage, self.shape.clone()))
    }
}
/// Dequantize a bitsandbytes blockwise-quantized tensor.
///
/// * `input` — packed quantized bytes (u8 storage).
/// * `absmax` — per-block f32 scales.
/// * `code` — f32 codebook (used by the Int8 path).
/// * `shape` — shape of the dequantized result.
/// * `blocksize` — quantization block size.
/// * `quant_ty` / `out_ty` — input quantization format and output dtype.
pub fn dequantize(
    input: &Tensor,
    absmax: &Tensor,
    code: &Tensor,
    shape: Shape,
    blocksize: usize,
    quant_ty: BnbQuantType,
    out_ty: BnbDType,
) -> Result<Tensor> {
    let op = DequantizeOp {
        n: input.elem_count(),
        blocksize,
        shape,
        quant_ty,
        out_ty,
    };
    input.apply_op3(absmax, code, op)
}