1#![allow(clippy::excessive_precision)]
2
3use std::fmt::Debug;
4
5#[cfg(feature = "cuda")]
6use candle_core::cuda::{
7 cudarc::driver::{sys::CUstream, CudaSlice, DeviceRepr, ValidAsZeroBits},
8 CudaDevice,
9};
10
11use candle_core::{
12 backend::BackendStorage, CpuStorage, CustomOp3, Result, Shape, Tensor, WithDType,
13};
14
15#[cfg(feature = "cuda")]
16use crate::bitsandbytes::ffi;
17
18use super::{BnbDType, BnbQuantType};
19
/// Custom three-operand op that dequantizes a block-quantized `u8` tensor.
///
/// It is applied through `apply_op3` with the quantized input, the per-block
/// `absmax` scales and the `code` lookup table as its operands.
struct DequantizeOp {
    /// Element count of the quantized input tensor (`input.elem_count()`).
    n: usize,
    /// Block length used to select the `absmax` entry for each element.
    blocksize: usize,
    /// Shape reported for the dequantized output tensor.
    shape: Shape,
    /// Quantization scheme of the input (`Int8`, `Fp4` or `Nf4`).
    quant_ty: BnbQuantType,
    /// Element type of the dequantized output.
    out_ty: BnbDType,
}
27
/// Maps a 4-bit NF4 code to its unscaled value in `[-1.0, 1.0]`.
///
/// Only the low nibble of `val` is inspected; the upper four bits are ignored.
fn d_dequantize_nf4(val: u8) -> f32 {
    // Indexed by the nibble value: code 0 is -1.0, code 7 is 0.0, code 15 is 1.0.
    const LEVELS: [f32; 16] = [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ];

    LEVELS[usize::from(val & 0x0F)]
}
93
/// Decodes a 4-bit FP4 code and scales it by `absmax`.
///
/// Bit 3 of `val` selects the sign (set means negative) and bits 0..=2 select
/// the magnitude. Bits above the low nibble are ignored.
fn d_dequantize_fp4_tree(val: u8, absmax: f32) -> f32 {
    // Magnitudes indexed by the three low bits of the code.
    const MAGNITUDES: [f32; 8] = [
        0.00000000,
        5.208333333e-03,
        0.66666667,
        1.00000000,
        0.33333333,
        0.50000000,
        0.16666667,
        0.25000000,
    ];

    let sign: f32 = if val & 0b1000 != 0 { -1.0 } else { 1.0 };
    let magnitude = MAGNITUDES[usize::from(val & 0b0111)];

    // Same evaluation order as before: (magnitude * absmax) * sign.
    magnitude * absmax * sign
}
128
129impl DequantizeOp {
130 fn dequantize_cpu<T: WithDType + Debug>(
131 &self,
132 input: &[u8],
133 absmax: &[f32],
134 code: &[f32],
135 quant_ty: BnbQuantType,
136 ) -> Vec<T> {
137 match quant_ty {
138 BnbQuantType::Int8 => {
139 let mut out = vec![T::zero(); self.n];
140 for block_idx in (0..self.n).step_by(self.blocksize) {
141 let valid_items = if self.n - block_idx >= self.blocksize {
142 self.blocksize
143 } else {
144 self.n - block_idx
145 };
146 let block_end = block_idx + valid_items;
147 for i in block_idx..block_end {
148 out[i] = T::from_f64(
149 (code[input[i] as usize] * absmax[block_idx / self.blocksize]) as f64,
150 );
151 }
152 }
153 out
154 }
155 BnbQuantType::Fp4 => {
156 let mut out = vec![T::zero(); self.shape.elem_count()];
157 for block_idx in (0..self.n).step_by(self.blocksize) {
158 let valid_items = if self.n > self.blocksize + block_idx {
159 self.blocksize
160 } else {
161 self.n - block_idx
162 };
163 let block_end = block_idx + valid_items;
164
165 let local_abs_max = absmax[block_idx / self.blocksize];
166
167 for i in block_idx..block_end {
168 out[i * 2] =
169 T::from_f64(d_dequantize_fp4_tree(input[i] >> 4, local_abs_max) as f64);
170 out[i * 2 + 1] = T::from_f64(d_dequantize_fp4_tree(
171 input[i] & 0x0F,
172 local_abs_max,
173 ) as f64);
174 }
175 }
176 out
177 }
178 BnbQuantType::Nf4 => {
179 let mut out = vec![T::zero(); self.shape.elem_count()];
180 for block_idx in (0..self.n).step_by(self.blocksize) {
181 let valid_items = if self.n > self.blocksize + block_idx {
182 self.blocksize
183 } else {
184 self.n - block_idx
185 };
186 let block_end = block_idx + valid_items;
187
188 let local_abs_max = absmax[block_idx / (self.blocksize / 2)];
189
190 for i in block_idx..block_end {
191 out[i * 2] =
192 T::from_f64((d_dequantize_nf4(input[i] >> 4) * local_abs_max) as f64);
193 out[i * 2 + 1] =
194 T::from_f64((d_dequantize_nf4(input[i] & 0x0F) * local_abs_max) as f64);
195 }
196 }
197 out
198 }
199 }
200 }
201
202 #[cfg(feature = "cuda")]
203 fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
204 &self,
205 input: &CudaSlice<u8>,
206 code: &CudaSlice<f32>,
207 absmax: &CudaSlice<f32>,
208 dev: &CudaDevice,
209 kernel: unsafe extern "C" fn(*const f32, *const u8, *const f32, *mut T, i32, i32, CUstream),
210 ) -> Result<CudaSlice<T>> {
211 use candle_core::cuda::{cudarc::driver::DevicePtr, WrapErr};
212
213 let out = unsafe { dev.alloc::<T>(self.shape.elem_count()).w()? };
214 unsafe {
215 kernel(
216 (*code.device_ptr()) as *const _,
217 (*input.device_ptr()) as *const _,
218 (*absmax.device_ptr()) as *const _,
219 (*out.device_ptr()) as *mut _,
220 self.blocksize as i32,
221 self.shape.elem_count() as i32,
222 *dev.cu_stream(),
223 )
224 };
225
226 Ok(out)
227 }
228}
229
230impl CustomOp3 for DequantizeOp {
231 fn name(&self) -> &'static str {
232 "dequantize-bnb"
233 }
234
235 fn cpu_fwd(
236 &self,
237 input_s: &CpuStorage,
238 input_l: &candle_core::Layout,
239 absmax_s: &CpuStorage,
240 absmax_l: &candle_core::Layout,
241 code_s: &CpuStorage,
242 code_l: &candle_core::Layout,
243 ) -> candle_core::Result<(CpuStorage, candle_core::Shape)> {
244 if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
245 candle_core::bail!("All inputs must be contiguous");
246 }
247 match (input_s, absmax_s, code_s, self.out_ty) {
248 (
249 CpuStorage::U8(input),
250 CpuStorage::F32(absmax),
251 CpuStorage::F32(code),
252 BnbDType::BF16,
253 ) => Ok((
254 CpuStorage::BF16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
255 self.shape.clone(),
256 )),
257 (
258 CpuStorage::U8(input),
259 CpuStorage::F32(absmax),
260 CpuStorage::F32(code),
261 BnbDType::F16,
262 ) => Ok((
263 CpuStorage::F16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
264 self.shape.clone(),
265 )),
266 (
267 CpuStorage::U8(input),
268 CpuStorage::F32(absmax),
269 CpuStorage::F32(code),
270 BnbDType::F32,
271 ) => Ok((
272 CpuStorage::F32(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
273 self.shape.clone(),
274 )),
275 (i, a, c, t) => candle_core::bail!(
276 "Unsupported dtypes for cpu dequant: {:?} input, {:?} absmax, {:?} code, {:?} out",
277 i.dtype(),
278 a.dtype(),
279 c.dtype(),
280 t
281 ),
282 }
283 }
284
285 #[cfg(feature = "cuda")]
286 fn cuda_fwd(
287 &self,
288 input_s: &candle_core::CudaStorage,
289 input_l: &candle_core::Layout,
290 absmax_s: &candle_core::CudaStorage,
291 absmax_l: &candle_core::Layout,
292 code_s: &candle_core::CudaStorage,
293 code_l: &candle_core::Layout,
294 ) -> Result<(candle_core::CudaStorage, Shape)> {
295 if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
296 candle_core::bail!("All inputs must be contiguous");
297 }
298 let input_slice = input_s.as_cuda_slice::<u8>()?;
299 let absmax_slice = absmax_s.as_cuda_slice::<f32>()?;
300 let code_slice = code_s.as_cuda_slice::<f32>()?;
301 let dev = input_s.device().clone();
302 let out = match (self.out_ty, self.quant_ty) {
303 (BnbDType::F32, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
304 self.dispatch_cuda_kernel::<f32>(
305 input_slice,
306 code_slice,
307 absmax_slice,
308 &dev,
309 ffi::dequantize_blockwise_f32_nf4,
310 )?,
311 dev,
312 ),
313 (BnbDType::F16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
314 self.dispatch_cuda_kernel::<half::f16>(
315 input_slice,
316 code_slice,
317 absmax_slice,
318 &dev,
319 ffi::dequantize_blockwise_f16_nf4,
320 )?,
321 dev,
322 ),
323 (BnbDType::BF16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
324 self.dispatch_cuda_kernel::<half::bf16>(
325 input_slice,
326 code_slice,
327 absmax_slice,
328 &dev,
329 ffi::dequantize_blockwise_bf16_nf4,
330 )?,
331 dev,
332 ),
333
334 (BnbDType::F32, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
335 self.dispatch_cuda_kernel::<f32>(
336 input_slice,
337 code_slice,
338 absmax_slice,
339 &dev,
340 ffi::dequantize_blockwise_f32_fp4,
341 )?,
342 dev,
343 ),
344 (BnbDType::F16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
345 self.dispatch_cuda_kernel::<half::f16>(
346 input_slice,
347 code_slice,
348 absmax_slice,
349 &dev,
350 ffi::dequantize_blockwise_f16_fp4,
351 )?,
352 dev,
353 ),
354 (BnbDType::BF16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
355 self.dispatch_cuda_kernel::<half::bf16>(
356 input_slice,
357 code_slice,
358 absmax_slice,
359 &dev,
360 ffi::dequantize_blockwise_bf16_fp4,
361 )?,
362 dev,
363 ),
364
365 (BnbDType::F32, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
366 self.dispatch_cuda_kernel::<f32>(
367 input_slice,
368 code_slice,
369 absmax_slice,
370 &dev,
371 ffi::dequantize_blockwise_f32_int8,
372 )?,
373 dev,
374 ),
375 (BnbDType::F16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
376 self.dispatch_cuda_kernel::<half::f16>(
377 input_slice,
378 code_slice,
379 absmax_slice,
380 &dev,
381 ffi::dequantize_blockwise_f16_int8,
382 )?,
383 dev,
384 ),
385 (BnbDType::BF16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
386 self.dispatch_cuda_kernel::<half::bf16>(
387 input_slice,
388 code_slice,
389 absmax_slice,
390 &dev,
391 ffi::dequantize_blockwise_bf16_int8,
392 )?,
393 dev,
394 ),
395 };
396
397 Ok((out, self.shape.clone()))
398 }
399
400 #[cfg(feature = "metal")]
401 fn metal_fwd(
402 &self,
403 input_s: &candle_core::MetalStorage,
404 input_l: &candle_core::Layout,
405 absmax_s: &candle_core::MetalStorage,
406 absmax_l: &candle_core::Layout,
407 code_s: &candle_core::MetalStorage,
408 code_l: &candle_core::Layout,
409 ) -> Result<(candle_core::MetalStorage, Shape)> {
410 use candle_core::DType;
411
412 if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
413 candle_core::bail!("All inputs must be contiguous");
414 }
415
416 let command_buffer = input_s.device().command_buffer()?;
417 command_buffer.set_label("dequant-bnb-nf4");
418
419 let device = input_s.device();
420
421 let output = device.new_buffer(
422 self.shape.elem_count(),
423 self.out_ty.into(),
424 "dequant-bnb-nf4",
425 )?;
426
427 if input_s.dtype() != DType::U8 {
428 candle_core::bail!("input must be u8");
429 }
430 if code_s.dtype() != DType::F32 {
431 candle_core::bail!("code must be f32");
432 }
433 if absmax_s.dtype() != DType::F32 {
434 candle_core::bail!("absmax must be f32");
435 }
436
437 match self.quant_ty {
438 BnbQuantType::Nf4 => crate::metal_kernels::call_dequant_bnb_nf4(
439 device.device(),
440 &command_buffer,
441 &crate::metal_kernels::Kernels::new(),
442 self.out_ty.into(),
443 input_s.buffer(),
444 absmax_s.buffer(),
445 code_s.buffer(),
446 &output,
447 self.blocksize,
448 self.n,
449 )
450 .map_err(candle_core::Error::wrap)?,
451 BnbQuantType::Fp4 => crate::metal_kernels::call_dequant_bnb_fp4(
452 device.device(),
453 &command_buffer,
454 &crate::metal_kernels::Kernels::new(),
455 self.out_ty.into(),
456 input_s.buffer(),
457 absmax_s.buffer(),
458 code_s.buffer(),
459 &output,
460 self.blocksize,
461 self.n,
462 )
463 .map_err(candle_core::Error::wrap)?,
464 BnbQuantType::Int8 => crate::metal_kernels::call_dequant_bnb_int8(
465 device.device(),
466 &command_buffer,
467 &crate::metal_kernels::Kernels::new(),
468 self.out_ty.into(),
469 input_s.buffer(),
470 absmax_s.buffer(),
471 code_s.buffer(),
472 &output,
473 self.blocksize,
474 self.n,
475 )
476 .map_err(candle_core::Error::wrap)?,
477 };
478
479 let newstorage = candle_core::MetalStorage::new(
480 output,
481 device.clone(),
482 self.shape.elem_count(),
483 self.out_ty.into(),
484 );
485 Ok((newstorage, self.shape.clone()))
486 }
487}
488
489pub fn dequantize(
490 input: &Tensor,
491 absmax: &Tensor,
492 code: &Tensor,
493 shape: Shape,
494 blocksize: usize,
495 quant_ty: BnbQuantType,
496 out_ty: BnbDType,
497) -> Result<Tensor> {
498 input.apply_op3(
499 absmax,
500 code,
501 DequantizeOp {
502 n: input.elem_count(),
503 blocksize,
504 shape,
505 quant_ty,
506 out_ty,
507 },
508 )
509}