#![allow(clippy::excessive_precision)]
use std::fmt::Debug;
#[cfg(feature = "cuda")]
use diffusion_rs_common::core::cuda::{
cudarc::driver::{sys::CUstream, CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits},
CudaDevice,
};
use diffusion_rs_common::core::{
backend::BackendStorage, CpuStorage, CustomOp2, CustomOp3, DType, Result, Shape, Tensor,
WithDType,
};
#[cfg(feature = "cuda")]
use crate::bitsandbytes::ffi;
use super::{BnbDType, BnbQuantType};
/// State for the block-wise bitsandbytes dequantize custom op.
///
/// Used with `CustomOp3` where the three inputs are: the packed quantized
/// payload, the per-block `absmax` scales, and the `code` codebook.
struct DequantizeOp {
    /// Element count of the packed input tensor. For fp4/nf4 each input byte
    /// packs two 4-bit codes, so the dequantized output holds up to `2 * n`
    /// values (see the `i * 2` indexing in `dequantize_cpu`).
    n: usize,
    /// Quantization block size; one `absmax` scale applies per block.
    blocksize: usize,
    /// Shape of the dequantized output tensor.
    shape: Shape,
    /// Source quantization format (Int8 / Fp4 / Nf4).
    quant_ty: BnbQuantType,
    /// Requested output dtype (F32 / F16 / BF16).
    out_ty: BnbDType,
}
/// Decodes the low nibble of `val` as an NF4 (4-bit normal-float) code,
/// returning the corresponding dequantized value in `[-1.0, 1.0]`.
/// The high nibble of `val` is ignored.
fn d_dequantize_nf4(val: u8) -> f32 {
    // NF4 codebook indexed directly by the 4-bit code. Entry `i` is the value
    // the bit-by-bit decision tree produces for code `i`: codes 0..=7 are the
    // negative half (code 7 maps to 0.0), codes 8..=15 the positive half.
    const NF4_CODEBOOK: [f32; 16] = [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ];
    NF4_CODEBOOK[(val & 0x0F) as usize]
}
/// Decodes the low nibble of `val` as an FP4 code scaled by `absmax`.
/// Bit 3 is the sign bit; bits 0..=2 select the magnitude.
fn d_dequantize_fp4_tree(val: u8, absmax: f32) -> f32 {
    // Magnitude table indexed by the low 3 bits; entry `i` is the value the
    // original decision tree yields for code `i` before applying sign/absmax.
    const FP4_MAGNITUDES: [f32; 8] = [
        0.00000000,      // 0b000
        5.208333333e-03, // 0b001 (subnormal step)
        0.66666667,      // 0b010
        1.00000000,      // 0b011
        0.33333333,      // 0b100
        0.50000000,      // 0b101
        0.16666667,      // 0b110
        0.25000000,      // 0b111
    ];
    let sign = if val & 0b1000 != 0 { -1.0 } else { 1.0 };
    FP4_MAGNITUDES[(val & 0b0111) as usize] * absmax * sign
}
impl DequantizeOp {
    /// CPU reference implementation of block-wise dequantization.
    ///
    /// * `input` — quantized payload: one int8 code per byte, or two 4-bit
    ///   codes per byte (high nibble first) for fp4/nf4.
    /// * `absmax` — per-block scale factors.
    /// * `code` — codebook indexed by the raw byte; only used by the int8 path.
    fn dequantize_cpu<T: WithDType + Debug>(
        &self,
        input: &[u8],
        absmax: &[f32],
        code: &[f32],
        quant_ty: BnbQuantType,
    ) -> Vec<T> {
        match quant_ty {
            BnbQuantType::Int8 => {
                // One output element per input byte.
                let mut out = vec![T::zero(); self.n];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // The final block may be shorter than `blocksize`.
                    let valid_items = if self.n - block_idx >= self.blocksize {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    for i in block_idx..block_end {
                        // Codebook lookup scaled by this block's absmax.
                        out[i] = T::from_f64(
                            (code[input[i] as usize] * absmax[block_idx / self.blocksize]) as f64,
                        );
                    }
                }
                out
            }
            BnbQuantType::Fp4 => {
                // Two output elements per input byte: high nibble then low.
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    let local_abs_max = absmax[block_idx / self.blocksize];
                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64(d_dequantize_fp4_tree(input[i] >> 4, local_abs_max) as f64);
                        out[i * 2 + 1] = T::from_f64(d_dequantize_fp4_tree(
                            input[i] & 0x0F,
                            local_abs_max,
                        ) as f64);
                    }
                }
                out
            }
            BnbQuantType::Nf4 => {
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    // NOTE(review): nf4 indexes absmax by `blocksize / 2`
                    // (two output values per input byte) while the Fp4 arm
                    // above divides by `blocksize` — confirm the intended
                    // absmax indexing convention for each 4-bit format
                    // against the callers' blocksize semantics.
                    let local_abs_max = absmax[block_idx / (self.blocksize / 2)];
                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64((d_dequantize_nf4(input[i] >> 4) * local_abs_max) as f64);
                        out[i * 2 + 1] =
                            T::from_f64((d_dequantize_nf4(input[i] & 0x0F) * local_abs_max) as f64);
                    }
                }
                out
            }
        }
    }

    /// Launches a bitsandbytes CUDA dequantize kernel over the full output.
    ///
    /// The kernel signature is `(code, input, absmax, out, blocksize, n,
    /// stream)`; `n` here is the output element count.
    #[cfg(feature = "cuda")]
    fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
        &self,
        input: CudaView<u8>,
        code: CudaView<f32>,
        absmax: CudaView<f32>,
        dev: &CudaDevice,
        kernel: unsafe extern "C" fn(*const f32, *const u8, *const f32, *mut T, i32, i32, CUstream),
    ) -> Result<CudaSlice<T>> {
        use diffusion_rs_common::core::cuda::{cudarc::driver::DevicePtr, WrapErr};
        // SAFETY: the buffer is uninitialized here; the kernel is expected to
        // write all `elem_count` elements before anyone reads it.
        let out = unsafe { dev.alloc::<T>(self.shape.elem_count()).w()? };
        // SAFETY: all device pointers come from live slices owned by the
        // caller/`out`, and the kernel is launched on this device's stream.
        unsafe {
            kernel(
                (*code.device_ptr()) as *const _,
                (*input.device_ptr()) as *const _,
                (*absmax.device_ptr()) as *const _,
                (*out.device_ptr()) as *mut _,
                self.blocksize as i32,
                self.shape.elem_count() as i32,
                *dev.cu_stream(),
            )
        };
        Ok(out)
    }
}
impl CustomOp3 for DequantizeOp {
    fn name(&self) -> &'static str {
        "dequantize-bnb"
    }

    /// CPU forward pass: validates contiguity, then dispatches on the
    /// requested output dtype and delegates to `dequantize_cpu`.
    ///
    /// Expects `input` as U8 and `absmax`/`code` as F32.
    fn cpu_fwd(
        &self,
        input_s: &CpuStorage,
        input_l: &diffusion_rs_common::core::Layout,
        absmax_s: &CpuStorage,
        absmax_l: &diffusion_rs_common::core::Layout,
        code_s: &CpuStorage,
        code_l: &diffusion_rs_common::core::Layout,
    ) -> diffusion_rs_common::core::Result<(CpuStorage, diffusion_rs_common::core::Shape)> {
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            diffusion_rs_common::bail!("All inputs must be contiguous");
        }
        match (input_s, absmax_s, code_s, self.out_ty) {
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::BF16,
            ) => Ok((
                CpuStorage::BF16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F16,
            ) => Ok((
                CpuStorage::F16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F32,
            ) => Ok((
                CpuStorage::F32(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (i, a, c, t) => diffusion_rs_common::bail!(
                "Unsupported dtypes for cpu dequant: {:?} input, {:?} absmax, {:?} code, {:?} out",
                i.dtype(),
                a.dtype(),
                c.dtype(),
                t
            ),
        }
    }

    /// CUDA forward pass: selects the FFI kernel matching
    /// `(out_ty, quant_ty)` and launches it via `dispatch_cuda_kernel`.
    ///
    /// All nine (dtype, format) combinations are covered, so the match is
    /// exhaustive without a fallback arm.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        input_s: &diffusion_rs_common::core::CudaStorage,
        input_l: &diffusion_rs_common::core::Layout,
        absmax_s: &diffusion_rs_common::core::CudaStorage,
        absmax_l: &diffusion_rs_common::core::Layout,
        code_s: &diffusion_rs_common::core::CudaStorage,
        code_l: &diffusion_rs_common::core::Layout,
    ) -> Result<(diffusion_rs_common::core::CudaStorage, Shape)> {
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            diffusion_rs_common::bail!("All inputs must be contiguous");
        }
        // Skip past each layout's start offset so the kernels see the first
        // logical element.
        let input_slice = input_s
            .as_cuda_slice::<u8>()?
            .slice(input_l.start_offset()..);
        let absmax_slice = absmax_s
            .as_cuda_slice::<f32>()?
            .slice(absmax_l.start_offset()..);
        let code_slice = code_s
            .as_cuda_slice::<f32>()?
            .slice(code_l.start_offset()..);
        let dev = input_s.device().clone();
        let out = match (self.out_ty, self.quant_ty) {
            (BnbDType::F32, BnbQuantType::Nf4) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<f32>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_f32_nf4,
                    )?,
                    dev,
                )
            }
            (BnbDType::F16, BnbQuantType::Nf4) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<half::f16>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_f16_nf4,
                    )?,
                    dev,
                )
            }
            (BnbDType::BF16, BnbQuantType::Nf4) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<half::bf16>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_bf16_nf4,
                    )?,
                    dev,
                )
            }
            (BnbDType::F32, BnbQuantType::Fp4) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<f32>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_f32_fp4,
                    )?,
                    dev,
                )
            }
            (BnbDType::F16, BnbQuantType::Fp4) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<half::f16>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_f16_fp4,
                    )?,
                    dev,
                )
            }
            (BnbDType::BF16, BnbQuantType::Fp4) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<half::bf16>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_bf16_fp4,
                    )?,
                    dev,
                )
            }
            (BnbDType::F32, BnbQuantType::Int8) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<f32>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_f32_int8,
                    )?,
                    dev,
                )
            }
            (BnbDType::F16, BnbQuantType::Int8) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<half::f16>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_f16_int8,
                    )?,
                    dev,
                )
            }
            (BnbDType::BF16, BnbQuantType::Int8) => {
                diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                    self.dispatch_cuda_kernel::<half::bf16>(
                        input_slice,
                        code_slice,
                        absmax_slice,
                        &dev,
                        ffi::dequantize_blockwise_bf16_int8,
                    )?,
                    dev,
                )
            }
        };
        Ok((out, self.shape.clone()))
    }

    /// Metal forward pass: validates dtypes (input U8, absmax/code F32) and
    /// dispatches to the matching metal kernel for the quant format.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        input_s: &diffusion_rs_common::core::MetalStorage,
        input_l: &diffusion_rs_common::core::Layout,
        absmax_s: &diffusion_rs_common::core::MetalStorage,
        absmax_l: &diffusion_rs_common::core::Layout,
        code_s: &diffusion_rs_common::core::MetalStorage,
        code_l: &diffusion_rs_common::core::Layout,
    ) -> Result<(diffusion_rs_common::core::MetalStorage, Shape)> {
        use diffusion_rs_common::core::DType;
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            diffusion_rs_common::bail!("All inputs must be contiguous");
        }
        let command_buffer = input_s.device().command_buffer()?;
        // NOTE(review): this label says "nf4" even for the fp4/int8 paths —
        // harmless, but misleading in GPU frame captures.
        command_buffer.set_label("dequant-bnb-nf4");
        let device = input_s.device();
        let output = device.new_buffer(
            self.shape.elem_count(),
            self.out_ty.into(),
            "dequant-bnb-nf4",
        )?;
        if input_s.dtype() != DType::U8 {
            diffusion_rs_common::bail!("input must be u8");
        }
        if code_s.dtype() != DType::F32 {
            diffusion_rs_common::bail!("code must be f32");
        }
        if absmax_s.dtype() != DType::F32 {
            diffusion_rs_common::bail!("absmax must be f32");
        }
        // Buffer offsets are given in bytes, hence the size_in_bytes scaling.
        match self.quant_ty {
            BnbQuantType::Nf4 => crate::metal_kernels::call_dequant_bnb_nf4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                input_l.start_offset() * input_s.dtype().size_in_bytes(),
                absmax_s.buffer(),
                absmax_l.start_offset() * absmax_s.dtype().size_in_bytes(),
                code_s.buffer(),
                code_l.start_offset() * code_s.dtype().size_in_bytes(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(diffusion_rs_common::core::Error::wrap)?,
            BnbQuantType::Fp4 => crate::metal_kernels::call_dequant_bnb_fp4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                input_l.start_offset() * input_s.dtype().size_in_bytes(),
                absmax_s.buffer(),
                absmax_l.start_offset() * absmax_s.dtype().size_in_bytes(),
                code_s.buffer(),
                code_l.start_offset() * code_s.dtype().size_in_bytes(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(diffusion_rs_common::core::Error::wrap)?,
            BnbQuantType::Int8 => crate::metal_kernels::call_dequant_bnb_int8(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                input_l.start_offset() * input_s.dtype().size_in_bytes(),
                absmax_s.buffer(),
                absmax_l.start_offset() * absmax_s.dtype().size_in_bytes(),
                code_s.buffer(),
                code_l.start_offset() * code_s.dtype().size_in_bytes(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(diffusion_rs_common::core::Error::wrap)?,
        };
        let newstorage = diffusion_rs_common::core::MetalStorage::new(
            output,
            device.clone(),
            self.shape.elem_count(),
            self.out_ty.into(),
        );
        Ok((newstorage, self.shape.clone()))
    }
}
/// Dequantizes a bitsandbytes-quantized tensor back to `out_ty`.
///
/// * `input` — packed quantized payload (u8 on CPU/CUDA/Metal backends).
/// * `absmax` — per-block scale factors (f32).
/// * `code` — codebook (f32; consumed by the int8 path).
/// * `shape` — shape of the dequantized result.
/// * `blocksize` — quantization block size.
/// * `quant_ty` / `out_ty` — source format and target dtype.
pub fn dequantize(
    input: &Tensor,
    absmax: &Tensor,
    code: &Tensor,
    shape: Shape,
    blocksize: usize,
    quant_ty: BnbQuantType,
    out_ty: BnbDType,
) -> Result<Tensor> {
    let op = DequantizeOp {
        n: input.elem_count(),
        blocksize,
        shape,
        quant_ty,
        out_ty,
    };
    input.apply_op3(absmax, code, op)
}
/// State for the bitsandbytes 8-bit dequantize op (`CustomOp2`: i8 weight
/// matrix + per-row f32 scales): `out[r][c] = weight[r][c] * scb[r] / 127`.
struct Dequantize8BitOp {
    /// Output dtype; must be F32, F16, or BF16 (enforced in the forward passes).
    out_ty: DType,
}
impl Dequantize8BitOp {
    /// CPU path: scales each i8 weight by its row's `scb` entry and divides
    /// by 127. `col` is the row length used to map a flat index to its row.
    fn dequantize_cpu<T: WithDType + Debug>(
        &self,
        weight: &[i8],
        scb: &[f32],
        col: usize,
    ) -> Vec<T> {
        weight
            .iter()
            .enumerate()
            .map(|(idx, &w)| {
                let row_scale = scb[idx / col];
                T::from_f64((w as f64 * row_scale as f64) / 127.)
            })
            .collect()
    }

    /// Launches the 8-bit CUDA dequantize kernel with signature
    /// `(weight, scb, out, row, col, n)`, where `n = row * col`.
    #[cfg(feature = "cuda")]
    fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
        &self,
        weight: CudaView<i8>,
        scb: CudaView<f32>,
        row: i32,
        col: i32,
        n: i32,
        dev: &CudaDevice,
        kernel: unsafe extern "C" fn(*const i8, *const f32, *mut T, i32, i32, i32),
    ) -> Result<CudaSlice<T>> {
        use diffusion_rs_common::core::cuda::{cudarc::driver::DevicePtr, WrapErr};
        // SAFETY: the buffer is uninitialized; the kernel is expected to
        // write all `n` output elements before it is read.
        let output = unsafe { dev.alloc::<T>(n as usize).w()? };
        // SAFETY: device pointers come from live views/`output` on `dev`.
        unsafe {
            kernel(
                (*weight.device_ptr()) as *const _,
                (*scb.device_ptr()) as *const _,
                (*output.device_ptr()) as *mut _,
                row,
                col,
                n,
            )
        };
        Ok(output)
    }
}
impl CustomOp2 for Dequantize8BitOp {
    fn name(&self) -> &'static str {
        "dequantize-8bit-bnb"
    }

    /// CPU forward pass: validates contiguity and that `scb` has one scale
    /// per weight row, then dispatches on the output dtype.
    ///
    /// Expects `weight` as I8 and `scb` as F32.
    fn cpu_fwd(
        &self,
        weight_s: &CpuStorage,
        weight_l: &diffusion_rs_common::core::Layout,
        scb_s: &CpuStorage,
        scb_l: &diffusion_rs_common::core::Layout,
    ) -> diffusion_rs_common::core::Result<(CpuStorage, diffusion_rs_common::core::Shape)> {
        if !(weight_l.is_contiguous() && scb_l.is_contiguous()) {
            diffusion_rs_common::bail!("All inputs must be contiguous");
        }
        let row = weight_l.dim(0)?;
        let col = weight_l.dim(1)?;
        if row != scb_l.dim(0)? {
            diffusion_rs_common::bail!("scb dim0 must match weight dim0");
        }
        match (weight_s, scb_s, self.out_ty) {
            (CpuStorage::I8(weight), CpuStorage::F32(scb), DType::BF16) => Ok((
                CpuStorage::BF16(self.dequantize_cpu(weight, scb, col)),
                weight_l.shape().clone(),
            )),
            (CpuStorage::I8(weight), CpuStorage::F32(scb), DType::F16) => Ok((
                CpuStorage::F16(self.dequantize_cpu(weight, scb, col)),
                weight_l.shape().clone(),
            )),
            (CpuStorage::I8(weight), CpuStorage::F32(scb), DType::F32) => Ok((
                CpuStorage::F32(self.dequantize_cpu(weight, scb, col)),
                weight_l.shape().clone(),
            )),
            (w, s, t) => diffusion_rs_common::bail!(
                "Unsupported dtypes for cpu dequant: {:?} weight, {:?} scb, {:?} out",
                w.dtype(),
                s.dtype(),
                t
            ),
        }
    }

    /// CUDA forward pass: selects the FFI kernel matching `out_ty` and
    /// launches it via `dispatch_cuda_kernel`.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        weight_s: &diffusion_rs_common::core::CudaStorage,
        weight_l: &diffusion_rs_common::core::Layout,
        scb_s: &diffusion_rs_common::core::CudaStorage,
        scb_l: &diffusion_rs_common::core::Layout,
    ) -> Result<(diffusion_rs_common::core::CudaStorage, Shape)> {
        if !(weight_l.is_contiguous() && scb_l.is_contiguous()) {
            diffusion_rs_common::bail!("All inputs must be contiguous");
        }
        // Skip past each layout's start offset so the kernel sees the first
        // logical element.
        let weight_slice = weight_s
            .as_cuda_slice::<i8>()?
            .slice(weight_l.start_offset()..);
        let scb_slice = scb_s.as_cuda_slice::<f32>()?.slice(scb_l.start_offset()..);
        let dev = weight_s.device().clone();
        let row = weight_l.dim(0)? as i32;
        let col = weight_l.dim(1)? as i32;
        let n = weight_l.shape().elem_count() as i32;
        if row != scb_l.dim(0)? as i32 {
            diffusion_rs_common::bail!("scb dim0 must match weight dim0");
        }
        let out = match self.out_ty {
            DType::F32 => diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel(
                    weight_slice,
                    scb_slice,
                    row,
                    col,
                    n,
                    &dev,
                    ffi::dequantize_8bit_kernel_f32,
                )?,
                dev,
            ),
            DType::F16 => diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel(
                    weight_slice,
                    scb_slice,
                    row,
                    col,
                    n,
                    &dev,
                    ffi::dequantize_8bit_kernel_f16,
                )?,
                dev,
            ),
            DType::BF16 => diffusion_rs_common::core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel(
                    weight_slice,
                    scb_slice,
                    row,
                    col,
                    n,
                    &dev,
                    ffi::dequantize_8bit_kernel_bf16,
                )?,
                dev,
            ),
            _ => diffusion_rs_common::bail!("only f32/bf16/f16 are allowed in dequantize-8bit-op"),
        };
        Ok((out, weight_l.shape().clone()))
    }

    /// Metal forward pass: validates dtypes (weight I8, scb F32) and calls
    /// the 8-bit metal kernel.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        weight_s: &diffusion_rs_common::core::MetalStorage,
        weight_l: &diffusion_rs_common::core::Layout,
        scb_s: &diffusion_rs_common::core::MetalStorage,
        scb_l: &diffusion_rs_common::core::Layout,
    ) -> Result<(diffusion_rs_common::core::MetalStorage, Shape)> {
        use diffusion_rs_common::core::DType;
        if !(weight_l.is_contiguous() && scb_l.is_contiguous()) {
            diffusion_rs_common::bail!("All inputs must be contiguous");
        }
        let command_buffer = weight_s.device().command_buffer()?;
        // Fix: the label previously said "dequant-bnb-nf4", copy-pasted from
        // the 4-bit op; label this buffer as the 8-bit dequant it performs.
        command_buffer.set_label("dequant-bnb-8bit");
        let device = weight_s.device();
        let row = weight_l.dim(0)?;
        let col = weight_l.dim(1)?;
        let n = weight_l.shape().elem_count();
        let output = device.new_buffer(n, self.out_ty, "dequant-8bit-bnb")?;
        if weight_s.dtype() != DType::I8 {
            diffusion_rs_common::bail!("input must be i8");
        }
        if scb_s.dtype() != DType::F32 {
            diffusion_rs_common::bail!("scb must be f32");
        }
        if row != scb_l.dim(0)? {
            diffusion_rs_common::bail!("scb dim0 must match weight dim0");
        }
        // Buffer offsets are given in bytes, hence the size_in_bytes scaling.
        crate::metal_kernels::call_dequant_bnb_8bit(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            self.out_ty,
            weight_s.buffer(),
            weight_l.start_offset() * weight_s.dtype().size_in_bytes(),
            scb_s.buffer(),
            scb_l.start_offset() * scb_s.dtype().size_in_bytes(),
            &output,
            row,
            col,
            n,
        )
        .map_err(diffusion_rs_common::core::Error::wrap)?;
        let newstorage =
            diffusion_rs_common::core::MetalStorage::new(output, device.clone(), n, self.out_ty);
        Ok((newstorage, weight_l.shape().clone()))
    }
}
/// Dequantizes an 8-bit bitsandbytes weight matrix using per-row `scb`
/// scales (`out = weight * scb[row] / 127`), producing a tensor of `out_ty`.
pub fn dequantize_8bit(weight: &Tensor, scb: &Tensor, out_ty: DType) -> Result<Tensor> {
    let op = Dequantize8BitOp { out_ty };
    weight.apply_op2(scb, op)
}