1#![allow(clippy::excessive_precision)]
2
3use std::fmt::Debug;
4
5#[cfg(feature = "cuda")]
6use candle_core::cuda::{
7 cudarc::driver::{sys::CUstream, CudaSlice, DeviceRepr, ValidAsZeroBits},
8 CudaDevice,
9};
10
11use candle_core::{
12 backend::BackendStorage, CpuStorage, CustomOp3, Result, Shape, Tensor, WithDType,
13};
14
15#[cfg(feature = "cuda")]
16use crate::bitsandbytes::ffi;
17
18use super::{BnbDType, BnbQuantType};
19
/// Parameters for one blockwise dequantization, run via candle's `CustomOp3`
/// with operands `(quantized input, absmax, code)`.
struct DequantizeOp {
    // Element count of the quantized input tensor (it is u8, so this is a
    // byte count); set from `input.elem_count()` in `dequantize` below.
    n: usize,
    // Quantization block size — values in one block share one absmax scale.
    blocksize: usize,
    // Shape of the dequantized output tensor.
    shape: Shape,
    // Quantization format: Int8, Fp4 or Nf4.
    quant_ty: BnbQuantType,
    // Output dtype: F32, F16 or BF16.
    out_ty: BnbDType,
}
27
/// Decode one 4-bit NF4 code into its normalized float value.
///
/// Only the low nibble of `val` is inspected; callers pass `byte >> 4` or
/// `byte & 0x0F`. The 16 constants are the NF4 codebook (quantiles of a
/// normal distribution, normalized to [-1, 1]).
fn d_dequantize_nf4(val: u8) -> f32 {
    match val & 0x0F {
        0b1111 => 1.0,
        0b1110 => 0.7229568362236023,
        0b1101 => 0.5626170039176941,
        0b1100 => 0.44070982933044434,
        0b1011 => 0.33791524171829224,
        0b1010 => 0.24611230194568634,
        0b1001 => 0.16093020141124725,
        0b1000 => 0.07958029955625534,
        0b0111 => 0.0,
        0b0110 => -0.09105003625154495,
        0b0101 => -0.18477343022823334,
        0b0100 => -0.28444138169288635,
        0b0011 => -0.39491748809814453,
        0b0010 => -0.5250730514526367,
        0b0001 => -0.6961928009986877,
        // `val & 0x0F` leaves only 0b0000 here.
        _ => -1.0,
    }
}
93
94fn d_dequantize_fp4_tree(val: u8, absmax: f32) -> f32 {
95 let sign = if (val & 0b1000) == 0b1000 { -1.0 } else { 1.0 };
96
97 if (val & 0b0100) == 0b0100 {
98 if (val & 0b0010) == 0b0010 {
100 if (val & 0b0001) == 0b0001 {
102 0.25000000 * absmax * sign } else {
105 0.16666667 * absmax * sign }
107 } else if (val & 0b0001) == 0b0001 {
108 0.50000000 * absmax * sign } else {
111 0.33333333 * absmax * sign }
113 } else if (val & 0b0010) == 0b0010 {
114 if (val & 0b0001) == 0b0001 {
116 1.00000000 * absmax * sign } else {
119 0.66666667 * absmax * sign }
121 } else if (val & 0b0001) == 0b0001 {
122 5.208333333e-03 * absmax * sign } else {
125 0.00000000 * absmax * sign }
127}
128
impl DequantizeOp {
    /// CPU reference path: dequantize `input` into a freshly allocated `Vec<T>`.
    ///
    /// * `input`  - quantized bytes (one value per byte for Int8; two packed
    ///              4-bit values per byte for Fp4/Nf4, high nibble first).
    /// * `absmax` - per-block scale factors.
    /// * `code`   - 256-entry dequantization lookup table (used by Int8 only).
    /// * `quant_ty` - which decode routine to apply.
    fn dequantize_cpu<T: WithDType + Debug>(
        &self,
        input: &[u8],
        absmax: &[f32],
        code: &[f32],
        quant_ty: BnbQuantType,
    ) -> Vec<T> {
        match quant_ty {
            BnbQuantType::Int8 => {
                // One output element per input byte: out[i] = code[byte] * block scale.
                // NOTE(review): this path sizes the output as `self.n` while the
                // 4-bit paths use `self.shape.elem_count()`; for Int8 the two
                // should coincide — confirm against callers.
                let mut out = vec![T::zero(); self.n];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // The last block may be shorter than `blocksize`.
                    let valid_items = if self.n - block_idx >= self.blocksize {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;
                    for i in block_idx..block_end {
                        out[i] = T::from_f64(
                            (code[input[i] as usize] * absmax[block_idx / self.blocksize]) as f64,
                        );
                    }
                }
                out
            }
            BnbQuantType::Fp4 => {
                // Two output elements per input byte (high nibble first).
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // Clamp the final (possibly partial) block.
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;

                    // NOTE(review): Fp4 indexes absmax by `blocksize`, but the
                    // Nf4 path below divides by `blocksize / 2`. This implies a
                    // different `blocksize` convention (bytes vs values) per
                    // quant type — confirm against the quantizer/callers.
                    let local_abs_max = absmax[block_idx / self.blocksize];

                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64(d_dequantize_fp4_tree(input[i] >> 4, local_abs_max) as f64);
                        out[i * 2 + 1] = T::from_f64(d_dequantize_fp4_tree(
                            input[i] & 0x0F,
                            local_abs_max,
                        ) as f64);
                    }
                }
                out
            }
            BnbQuantType::Nf4 => {
                // Same byte packing as Fp4, decoded with the NF4 codebook.
                let mut out = vec![T::zero(); self.shape.elem_count()];
                for block_idx in (0..self.n).step_by(self.blocksize) {
                    // Clamp the final (possibly partial) block.
                    let valid_items = if self.n > self.blocksize + block_idx {
                        self.blocksize
                    } else {
                        self.n - block_idx
                    };
                    let block_end = block_idx + valid_items;

                    // Here `blocksize / 2` input bytes appear to share one
                    // scale (i.e. `blocksize` counts output values) — see the
                    // NOTE on the Fp4 path above.
                    let local_abs_max = absmax[block_idx / (self.blocksize / 2)];

                    for i in block_idx..block_end {
                        out[i * 2] =
                            T::from_f64((d_dequantize_nf4(input[i] >> 4) * local_abs_max) as f64);
                        out[i * 2 + 1] =
                            T::from_f64((d_dequantize_nf4(input[i] & 0x0F) * local_abs_max) as f64);
                    }
                }
                out
            }
        }
    }

    /// Launch one of the bitsandbytes FFI dequantize kernels over the device
    /// buffers and return the newly allocated output slice.
    ///
    /// `kernel`'s argument order is `(code, input, absmax, out, blocksize, n,
    /// stream)` as fixed by the fn-pointer type below.
    #[cfg(feature = "cuda")]
    fn dispatch_cuda_kernel<T: WithDType + DeviceRepr + ValidAsZeroBits>(
        &self,
        input: &CudaSlice<u8>,
        code: &CudaSlice<f32>,
        absmax: &CudaSlice<f32>,
        dev: &CudaDevice,
        kernel: unsafe extern "C" fn(*const f32, *const u8, *const f32, *mut T, i32, i32, CUstream),
    ) -> Result<CudaSlice<T>> {
        use crate::utils::slice_ptr;

        // Uninitialized output buffer — assumes the kernel writes all
        // `shape.elem_count()` elements (TODO confirm for partial blocks).
        let out = unsafe { dev.alloc::<T>(self.shape.elem_count())? };

        // `slice_ptr` yields (device pointer, guard); the guards are kept
        // alive across the kernel launch so the pointers stay valid.
        let (code, _code_guard) = slice_ptr(code, 0);
        let (input, _input_guard) = slice_ptr(input, 0);
        let (absmax, _absmax_guard) = slice_ptr(absmax, 0);
        let (out_ptr, out_guard) = slice_ptr(&out, 0);

        unsafe {
            kernel(
                code as *const _,
                input as *const _,
                absmax as *const _,
                out_ptr as *mut _,
                self.blocksize as i32,
                self.shape.elem_count() as i32,
                dev.cuda_stream().cu_stream(),
            )
        };

        // Release the borrow on `out` before handing it back.
        drop(out_guard);

        Ok(out)
    }
}
237
// `CustomOp3` operands arrive in the order used by `Tensor::apply_op3` in
// `dequantize` below: (quantized input, absmax, code).
impl CustomOp3 for DequantizeOp {
    /// Identifier candle uses for this op (error messages, tracing).
    fn name(&self) -> &'static str {
        "dequantize-bnb"
    }

    /// CPU forward: dispatch on the operand storage dtypes and the requested
    /// output dtype, then run the scalar reference implementation.
    fn cpu_fwd(
        &self,
        input_s: &CpuStorage,
        input_l: &candle_core::Layout,
        absmax_s: &CpuStorage,
        absmax_l: &candle_core::Layout,
        code_s: &CpuStorage,
        code_l: &candle_core::Layout,
    ) -> candle_core::Result<(CpuStorage, candle_core::Shape)> {
        // The scalar loops index flat slices, so all operands must be contiguous.
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        // Only (u8 input, f32 absmax, f32 code) is supported; the arms differ
        // solely in the output dtype wrapper.
        match (input_s, absmax_s, code_s, self.out_ty) {
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::BF16,
            ) => Ok((
                CpuStorage::BF16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F16,
            ) => Ok((
                CpuStorage::F16(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            (
                CpuStorage::U8(input),
                CpuStorage::F32(absmax),
                CpuStorage::F32(code),
                BnbDType::F32,
            ) => Ok((
                CpuStorage::F32(self.dequantize_cpu(input, absmax, code, self.quant_ty)),
                self.shape.clone(),
            )),
            // Any other dtype combination is a usage error.
            (i, a, c, t) => candle_core::bail!(
                "Unsupported dtypes for cpu dequant: {:?} input, {:?} absmax, {:?} code, {:?} out",
                i.dtype(),
                a.dtype(),
                c.dtype(),
                t
            ),
        }
    }

    /// CUDA forward: pick the FFI kernel matching (output dtype, quant type)
    /// and wrap the resulting device slice back into a `CudaStorage`.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        input_s: &candle_core::CudaStorage,
        input_l: &candle_core::Layout,
        absmax_s: &candle_core::CudaStorage,
        absmax_l: &candle_core::Layout,
        code_s: &candle_core::CudaStorage,
        code_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        // Kernels assume flat contiguous buffers.
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }
        let input_slice = input_s.as_cuda_slice::<u8>()?;
        let absmax_slice = absmax_s.as_cuda_slice::<f32>()?;
        let code_slice = code_s.as_cuda_slice::<f32>()?;
        let dev = input_s.device().clone();
        // Nine combinations: {F32, F16, BF16} output x {Nf4, Fp4, Int8}.
        let out = match (self.out_ty, self.quant_ty) {
            (BnbDType::F32, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_nf4,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_nf4,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Nf4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_nf4,
                )?,
                dev,
            ),

            (BnbDType::F32, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_fp4,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_fp4,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Fp4) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_fp4,
                )?,
                dev,
            ),

            (BnbDType::F32, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<f32>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f32_int8,
                )?,
                dev,
            ),
            (BnbDType::F16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::f16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_f16_int8,
                )?,
                dev,
            ),
            (BnbDType::BF16, BnbQuantType::Int8) => candle_core::CudaStorage::wrap_cuda_slice(
                self.dispatch_cuda_kernel::<half::bf16>(
                    input_slice,
                    code_slice,
                    absmax_slice,
                    &dev,
                    ffi::dequantize_blockwise_bf16_int8,
                )?,
                dev,
            ),
        };

        Ok((out, self.shape.clone()))
    }

    /// Metal forward: validate operand dtypes, then dispatch the matching
    /// metal kernel into a freshly allocated output buffer.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        input_s: &candle_core::MetalStorage,
        input_l: &candle_core::Layout,
        absmax_s: &candle_core::MetalStorage,
        absmax_l: &candle_core::Layout,
        code_s: &candle_core::MetalStorage,
        code_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        use candle_core::DType;

        // Kernels assume flat contiguous buffers.
        if !(input_l.is_contiguous() && absmax_l.is_contiguous() && code_l.is_contiguous()) {
            candle_core::bail!("All inputs must be contiguous");
        }

        // NOTE(review): the "dequant-bnb-nf4" label below is used for all
        // three quant types, not just NF4 — possibly a stale copy-paste.
        let command_buffer = input_s.device().command_buffer()?;
        command_buffer.set_label("dequant-bnb-nf4");

        let device = input_s.device();

        let output = device.new_buffer(
            self.shape.elem_count(),
            self.out_ty.into(),
            "dequant-bnb-nf4",
        )?;

        // Unlike the CPU path (which matches on storage variants), dtypes are
        // checked explicitly here before touching raw buffers.
        if input_s.dtype() != DType::U8 {
            candle_core::bail!("input must be u8");
        }
        if code_s.dtype() != DType::F32 {
            candle_core::bail!("code must be f32");
        }
        if absmax_s.dtype() != DType::F32 {
            candle_core::bail!("absmax must be f32");
        }

        // NOTE(review): `Kernels::new()` is constructed per call in each arm;
        // if it caches pipeline state, consider sharing one instance.
        match self.quant_ty {
            BnbQuantType::Nf4 => crate::metal_kernels::call_dequant_bnb_nf4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
            BnbQuantType::Fp4 => crate::metal_kernels::call_dequant_bnb_fp4(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
            BnbQuantType::Int8 => crate::metal_kernels::call_dequant_bnb_int8(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                self.out_ty.into(),
                input_s.buffer(),
                absmax_s.buffer(),
                code_s.buffer(),
                &output,
                self.blocksize,
                self.n,
            )
            .map_err(candle_core::Error::wrap)?,
        };

        // Wrap the filled buffer as a MetalStorage of the requested dtype.
        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            self.shape.elem_count(),
            self.out_ty.into(),
        );
        Ok((newstorage, self.shape.clone()))
    }
}
496
497pub fn dequantize(
498 input: &Tensor,
499 absmax: &Tensor,
500 code: &Tensor,
501 shape: Shape,
502 blocksize: usize,
503 quant_ty: BnbQuantType,
504 out_ty: BnbDType,
505) -> Result<Tensor> {
506 input.apply_op3(
507 absmax,
508 code,
509 DequantizeOp {
510 n: input.elem_count(),
511 blocksize,
512 shape,
513 quant_ty,
514 out_ty,
515 },
516 )
517}