1#[cfg(feature = "cuda")]
2use candle_core::from_storage_no_op;
3use candle_core::{CpuStorage, CustomOp1, CustomOp2, DType, Result, Tensor, WithDType};
4use float8::F8E4M3;
5use rayon::iter::{IntoParallelIterator, ParallelIterator};
6
7struct Fp8BlockwiseDequantize {
8 weight_block_size: Vec<usize>,
9 out_ty: DType,
10}
11
12impl Fp8BlockwiseDequantize {
13 fn dispatch_dequant_blockwise<T: WithDType>(
14 &self,
15 weight: &[F8E4M3],
16 scale: &[f32],
17 weight_l: &candle_core::Layout,
18 scale_l: &candle_core::Layout,
19 ) -> candle_core::Result<Vec<T>> {
20 let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
21 let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);
22
23 let res = vec![T::zero(); weight.len()];
24
25 (0..grid_y).into_par_iter().for_each(|y| {
26 (0..grid_x).into_par_iter().for_each(|x| {
27 let res_ptr = res.as_ptr() as *mut T;
28
29 let scale = scale[y * scale_l.stride()[0] + x];
30
31 let start_y = y * self.weight_block_size[0];
32 let end_y = start_y + self.weight_block_size[0];
33
34 let start_x = x * self.weight_block_size[1];
35 let end_x = start_x + self.weight_block_size[1];
36
37 for weight_y in start_y..end_y {
38 if weight_y >= weight_l.dims()[0] {
39 break;
40 }
41
42 let row_offset = weight_y * weight_l.stride()[0];
43 for weight_x in start_x..end_x {
44 if weight_x >= weight_l.dims()[1] {
45 break;
46 }
47
48 let weight_pos = row_offset + weight_x;
49
50 unsafe {
52 *res_ptr.wrapping_add(weight_pos) =
53 T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
54 }
55 }
56 }
57 });
58 });
59
60 Ok(res)
61 }
62}
63
64impl CustomOp2 for Fp8BlockwiseDequantize {
65 fn name(&self) -> &'static str {
66 "fp8-blockwise-dequantize"
67 }
68
69 fn cpu_fwd(
70 &self,
71 scale_s: &candle_core::CpuStorage,
72 scale_l: &candle_core::Layout,
73 weight_s: &candle_core::CpuStorage,
74 weight_l: &candle_core::Layout,
75 ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
76 let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
77 candle_core::bail!("Expected F8E4M3 weight!");
78 };
79 let candle_core::CpuStorage::F32(scale) = scale_s else {
80 candle_core::bail!("Expected F8E4M3 weight!");
81 };
82 if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
83 candle_core::bail!("Expected weight to have start offset 0, continuous");
84 }
85 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
86 candle_core::bail!("Expected scales to have start offset 0, continuous");
87 }
88 if weight_l.dims().len() != 2 {
89 candle_core::bail!("Expected weight to be rank 2");
90 }
91 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
92 candle_core::bail!("Expected scale to be rank 2");
93 }
94
95 match self.out_ty {
96 DType::F32 => Ok((
97 CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
98 weight_l.shape().clone(),
99 )),
100 DType::BF16 => Ok((
101 CpuStorage::BF16(
102 self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
103 ),
104 weight_l.shape().clone(),
105 )),
106 DType::F16 => Ok((
107 CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
108 weight_l.shape().clone(),
109 )),
110 other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
111 }
112 }
113
114 #[cfg(feature = "cuda")]
115 fn cuda_fwd(
116 &self,
117 scale_s: &candle_core::CudaStorage,
118 scale_l: &candle_core::Layout,
119 weight_s: &candle_core::CudaStorage,
120 weight_l: &candle_core::Layout,
121 ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
122 use candle_core::{backend::BackendStorage, CudaStorage};
123 use half::{bf16, f16};
124
125 use crate::{blockwise_fp8::ffi, utils::slice_ptr};
126
127 if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
128 candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
129 }
130
131 if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
132 candle_core::bail!("Expected weight to have start offset 0, continuous");
133 }
134 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
135 candle_core::bail!("Expected scales to have start offset 0, continuous");
136 }
137 if weight_l.dims().len() != 2 {
138 candle_core::bail!("Expected weight to be rank 2");
139 }
140 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
141 candle_core::bail!("Expected scale to be rank 2");
142 }
143
144 let dev = weight_s.device();
145
146 let (weight, _weight_guard) =
147 slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
148 let (scale, _scale_guard) =
149 slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());
150
151 let weight_height = weight_l.dim(0)? as i32;
152 let weight_block_size_x = self.weight_block_size[0] as i32;
153 let weight_width = weight_l.dim(1)? as i32;
154 let weight_block_size_y = self.weight_block_size[1] as i32;
155 let scale_stride = scale_l.stride()[0] as i32;
156 let weight_row_stride = weight_l.stride()[0] as i32;
157
158 let res = match self.out_ty {
159 DType::F32 => {
160 let output = weight_s
161 .device()
162 .alloc_zeros::<f32>(weight_l.shape().elem_count())?;
163 let (output_ptr, output_guard) = slice_ptr(&output, 0);
164 unsafe {
165 ffi::launch_dequant_fp8_blockwise_kernel_f32(
166 weight as *const _,
167 scale as *const _,
168 output_ptr as *mut _,
169 weight_height,
170 weight_width,
171 weight_row_stride,
172 scale_stride,
173 weight_block_size_y,
174 weight_block_size_x,
175 dev.cuda_stream().cu_stream(),
176 )
177 };
178 drop(output_guard);
179 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
180 }
181 DType::F16 => {
182 let output = weight_s
183 .device()
184 .alloc_zeros::<f16>(weight_l.shape().elem_count())?;
185 let (output_ptr, output_guard) = slice_ptr(&output, 0);
186 unsafe {
187 ffi::launch_dequant_fp8_blockwise_kernel_f16(
188 weight as *const _,
189 scale as *const _,
190 output_ptr as *mut _,
191 weight_height,
192 weight_width,
193 weight_row_stride,
194 scale_stride,
195 weight_block_size_y,
196 weight_block_size_x,
197 dev.cuda_stream().cu_stream(),
198 )
199 };
200 drop(output_guard);
201 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
202 }
203 DType::BF16 => {
204 let output = weight_s
205 .device()
206 .alloc_zeros::<bf16>(weight_l.shape().elem_count())?;
207 let (output_ptr, output_guard) = slice_ptr(&output, 0);
208 unsafe {
209 ffi::launch_dequant_fp8_blockwise_kernel_bf16(
210 weight as *const _,
211 scale as *const _,
212 output_ptr as *mut _,
213 weight_height,
214 weight_width,
215 weight_row_stride,
216 scale_stride,
217 weight_block_size_y,
218 weight_block_size_x,
219 dev.cuda_stream().cu_stream(),
220 )
221 };
222 drop(output_guard);
223 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
224 }
225 other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
226 };
227
228 Ok((res, weight_l.shape().clone()))
229 }
230
231 #[cfg(feature = "metal")]
232 fn metal_fwd(
233 &self,
234 scale_s: &candle_core::MetalStorage,
235 scale_l: &candle_core::Layout,
236 weight_s: &candle_core::MetalStorage,
237 weight_l: &candle_core::Layout,
238 ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
239 use candle_core::backend::BackendStorage;
240
241 if weight_l.start_offset() != 0
242 || !weight_l.is_contiguous()
243 || weight_s.dtype() != DType::F8E4M3
244 {
245 candle_core::bail!("Expected f8e4m3 weight to have start offset 0, continuous");
246 }
247 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() || scale_s.dtype() != DType::F32
248 {
249 candle_core::bail!("Expected f32 scales to have start offset 0, continuous");
250 }
251 if weight_l.dims().len() != 2 {
252 candle_core::bail!("Expected weight to be rank 2");
253 }
254 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
255 candle_core::bail!("Expected scale to be rank 2");
256 }
257
258 let command_buffer = weight_s.device().command_buffer()?;
259 command_buffer.set_label("dequant-blockwise-fp8");
260
261 let device = weight_s.device();
262
263 let out_shape = weight_l.shape().clone();
264
265 let output = device.new_buffer(
266 out_shape.elem_count(),
267 weight_s.dtype(),
268 "dequant-blockwise-fp8",
269 )?;
270
271 let weight_height = weight_l.dim(0)? as u32;
272 let weight_block_size_x = self.weight_block_size[0] as u32;
273 let weight_width = weight_l.dim(1)? as u32;
274 let weight_block_size_y = self.weight_block_size[1] as u32;
275 let scale_stride = scale_l.stride()[0] as u32;
276 let weight_row_stride = weight_l.stride()[0] as u32;
277
278 crate::metal_kernels::call_dequant_blockwise_fp8(
279 device.device(),
280 &command_buffer,
281 &crate::metal_kernels::Kernels::new(),
282 self.out_ty,
283 weight_s.buffer(),
284 scale_s.buffer(),
285 &output,
286 weight_height,
287 weight_width,
288 weight_row_stride,
289 scale_stride,
290 weight_block_size_y,
291 weight_block_size_x,
292 )
293 .map_err(candle_core::Error::wrap)?;
294
295 let newstorage = candle_core::MetalStorage::new(
296 output,
297 device.clone(),
298 out_shape.elem_count(),
299 self.out_ty,
300 );
301 Ok((newstorage, out_shape))
302 }
303}
304
305pub fn fp8_blockwise_dequantize(
310 weight: &Tensor,
311 inv_scales: &Tensor,
312 weight_block_size: Vec<usize>,
313 out_ty: DType,
314) -> Result<Tensor> {
315 inv_scales.apply_op2_no_bwd(
316 weight,
317 &Fp8BlockwiseDequantize {
318 weight_block_size,
319 out_ty,
320 },
321 )
322}
323
/// Custom op that quantizes a rank-2 tensor to FP8 (E4M3) with one scale per block.
// Nothing in the visible part of this file constructs this op (the public
// `fp8_blockwise_quantize` launches the kernels directly), hence the lint allowance.
#[allow(dead_code)]
struct Fp8BlockwiseQuantize {
    /// Two block extents, used as `[rows, cols]` of one scale block.
    weight_block_size: Vec<usize>,
}
328
329impl Fp8BlockwiseQuantize {
330 #[allow(dead_code)]
331 fn dispatch_quant_blockwise<T: WithDType>(
332 &self,
333 input: &[T],
334 input_l: &candle_core::Layout,
335 ) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
336 let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
337 let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
338
339 let weight = vec![F8E4M3::from_f32(0.0); input.len()];
340 let scale = vec![0f32; grid_y * grid_x];
341
342 (0..grid_y).into_par_iter().for_each(|y| {
343 (0..grid_x).into_par_iter().for_each(|x| {
344 let weight_ptr = weight.as_ptr() as *mut F8E4M3;
345 let scale_ptr = scale.as_ptr() as *mut f32;
346
347 let start_y = y * self.weight_block_size[0];
348 let end_y = start_y + self.weight_block_size[0];
349
350 let start_x = x * self.weight_block_size[1];
351 let end_x = start_x + self.weight_block_size[1];
352
353 let mut max_abs = 0f32;
355 for weight_y in start_y..end_y {
356 if weight_y >= input_l.dims()[0] {
357 break;
358 }
359
360 let row_offset = weight_y * input_l.stride()[0];
361 for weight_x in start_x..end_x {
362 if weight_x >= input_l.dims()[1] {
363 break;
364 }
365
366 let pos = row_offset + weight_x;
367 let val = input[pos].to_f64() as f32;
368 let abs_val = val.abs();
369 if abs_val > max_abs {
370 max_abs = abs_val;
371 }
372 }
373 }
374
375 let block_scale = if max_abs > 0.0 {
377 max_abs / 448.0
378 } else {
379 1e-12
380 };
381
382 unsafe {
384 *scale_ptr.wrapping_add(y * grid_x + x) = block_scale;
385 }
386
387 for weight_y in start_y..end_y {
389 if weight_y >= input_l.dims()[0] {
390 break;
391 }
392
393 let row_offset = weight_y * input_l.stride()[0];
394 for weight_x in start_x..end_x {
395 if weight_x >= input_l.dims()[1] {
396 break;
397 }
398
399 let pos = row_offset + weight_x;
400 let val = input[pos].to_f64() as f32;
401 let scaled_val = (val / block_scale).clamp(-448.0, 448.0);
402
403 unsafe {
405 *weight_ptr.wrapping_add(pos) = F8E4M3::from_f32(scaled_val);
406 }
407 }
408 }
409 });
410 });
411
412 Ok((weight, scale))
413 }
414}
415
416impl CustomOp1 for Fp8BlockwiseQuantize {
417 fn name(&self) -> &'static str {
418 "fp8-blockwise-quantize"
419 }
420
421 fn cpu_fwd(
422 &self,
423 input_s: &candle_core::CpuStorage,
424 input_l: &candle_core::Layout,
425 ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
426 if input_l.start_offset() != 0 || !input_l.is_contiguous() {
427 candle_core::bail!("Expected input to have start offset 0, continuous");
428 }
429 if input_l.dims().len() != 2 {
430 candle_core::bail!("Expected input to be rank 2");
431 }
432 if self.weight_block_size.len() != 2 {
433 candle_core::bail!("Expected weight_block_size to have length 2");
434 }
435
436 let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
437 let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
438
439 let (weight, scale) = match input_s {
440 CpuStorage::F32(input) => self.dispatch_quant_blockwise(input, input_l)?,
441 CpuStorage::F16(input) => self.dispatch_quant_blockwise(input, input_l)?,
442 CpuStorage::BF16(input) => self.dispatch_quant_blockwise(input, input_l)?,
443 other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
444 };
445
446 let mut packed = Vec::with_capacity(weight.len() + scale.len());
449 packed.extend_from_slice(&weight);
450
451 for &s in &scale {
453 packed.push(F8E4M3::from_f32(s));
454 }
455
456 Ok((
457 CpuStorage::F8E4M3(packed),
458 candle_core::Shape::from_dims(&[
459 input_l.dims()[0] + grid_y,
460 input_l.dims()[1].max(grid_x),
461 ]),
462 ))
463 }
464
465 #[cfg(feature = "cuda")]
466 fn cuda_fwd(
467 &self,
468 input_s: &candle_core::CudaStorage,
469 input_l: &candle_core::Layout,
470 ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
471 use candle_core::{backend::BackendStorage, CudaStorage};
472 use half::{bf16, f16};
473
474 use crate::{blockwise_fp8::ffi, utils::slice_ptr};
475
476 if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
477 candle_core::bail!("Do not have blockwise FP8 quant kernels.");
478 }
479
480 if input_l.start_offset() != 0 || !input_l.is_contiguous() {
481 candle_core::bail!("Expected input to have start offset 0, continuous");
482 }
483 if input_l.dims().len() != 2 {
484 candle_core::bail!("Expected input to be rank 2");
485 }
486 if self.weight_block_size.len() != 2 {
487 candle_core::bail!("Expected weight_block_size to have length 2");
488 }
489
490 let dev = input_s.device();
491
492 let weight_height = input_l.dim(0)? as i32;
493 let weight_block_size_y = self.weight_block_size[0] as i32;
494 let weight_width = input_l.dim(1)? as i32;
495 let weight_block_size_x = self.weight_block_size[1] as i32;
496 let weight_row_stride = input_l.stride()[0] as i32;
497
498 let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
499 let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
500 let scale_stride = grid_x as i32;
501
502 let weight_output = dev.alloc_zeros::<F8E4M3>(input_l.shape().elem_count())?;
504 let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;
505
506 let (weight_ptr, weight_guard) = slice_ptr(&weight_output, 0);
507 let (scale_ptr, scale_guard) = slice_ptr(&scale_output, 0);
508
509 match input_s.dtype() {
510 DType::F32 => {
511 let (input, _input_guard) =
512 slice_ptr(input_s.as_cuda_slice::<f32>()?, input_l.start_offset());
513 unsafe {
514 ffi::launch_quant_fp8_blockwise_kernel_f32(
515 input as *const _,
516 weight_ptr as *mut _,
517 scale_ptr as *mut _,
518 weight_height,
519 weight_width,
520 weight_row_stride,
521 scale_stride,
522 weight_block_size_y,
523 weight_block_size_x,
524 dev.cuda_stream().cu_stream(),
525 )
526 };
527 }
528 DType::F16 => {
529 let (input, _input_guard) =
530 slice_ptr(input_s.as_cuda_slice::<f16>()?, input_l.start_offset());
531 unsafe {
532 ffi::launch_quant_fp8_blockwise_kernel_f16(
533 input as *const _,
534 weight_ptr as *mut _,
535 scale_ptr as *mut _,
536 weight_height,
537 weight_width,
538 weight_row_stride,
539 scale_stride,
540 weight_block_size_y,
541 weight_block_size_x,
542 dev.cuda_stream().cu_stream(),
543 )
544 };
545 }
546 DType::BF16 => {
547 let (input, _input_guard) =
548 slice_ptr(input_s.as_cuda_slice::<bf16>()?, input_l.start_offset());
549 unsafe {
550 ffi::launch_quant_fp8_blockwise_kernel_bf16(
551 input as *const _,
552 weight_ptr as *mut _,
553 scale_ptr as *mut _,
554 weight_height,
555 weight_width,
556 weight_row_stride,
557 scale_stride,
558 weight_block_size_y,
559 weight_block_size_x,
560 dev.cuda_stream().cu_stream(),
561 )
562 };
563 }
564 other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
565 }
566
567 drop(weight_guard);
568 drop(scale_guard);
569
570 let res = CudaStorage::wrap_cuda_slice(weight_output, input_s.device().clone());
572 Ok((res, input_l.shape().clone()))
573 }
574
575 #[cfg(feature = "metal")]
576 fn metal_fwd(
577 &self,
578 _input_s: &candle_core::MetalStorage,
579 _input_l: &candle_core::Layout,
580 ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
581 candle_core::bail!("FP8 blockwise quantization not yet implemented for Metal");
582 }
583}
584
585pub fn fp8_blockwise_quantize(
591 #[allow(unused_variables)] input: &Tensor,
592 #[allow(unused_variables)] weight_block_size: Vec<usize>,
593) -> Result<(Tensor, Tensor)> {
594 #[cfg(feature = "cuda")]
597 {
598 use candle_core::{CudaStorage, Device, Storage};
599 use half::{bf16, f16};
600
601 use crate::{blockwise_fp8::ffi, utils::slice_ptr};
602
603 if !matches!(input.device(), Device::Cuda(_)) {
604 candle_core::bail!("FP8 blockwise quantization only supported on CUDA for now");
605 }
606
607 if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
608 candle_core::bail!("Do not have blockwise FP8 quant kernels.");
609 }
610
611 let input_l = input.layout();
612 if input_l.start_offset() != 0 || !input_l.is_contiguous() {
613 candle_core::bail!("Expected input to have start offset 0, continuous");
614 }
615 if input.dims().len() != 2 {
616 candle_core::bail!("Expected input to be rank 2");
617 }
618 if weight_block_size.len() != 2 {
619 candle_core::bail!("Expected weight_block_size to have length 2");
620 }
621
622 let dev = match input.device() {
623 Device::Cuda(dev) => dev,
624 _ => unreachable!(),
625 };
626
627 let weight_height = input.dim(0)? as i32;
628 let weight_block_size_y = weight_block_size[0] as i32;
629 let weight_width = input.dim(1)? as i32;
630 let weight_block_size_x = weight_block_size[1] as i32;
631 let weight_row_stride = input_l.stride()[0] as i32;
632
633 let grid_y = input.dim(0)?.div_ceil(weight_block_size[0]);
634 let grid_x = input.dim(1)?.div_ceil(weight_block_size[1]);
635 let scale_stride = grid_x as i32;
636
637 let weight_output = dev.alloc_zeros::<F8E4M3>(input.shape().elem_count())?;
639 let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;
640
641 let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
642 let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);
643
644 match input.dtype() {
645 DType::F32 => {
646 let input_storage = input.storage_and_layout().0;
647 let input_s = match &*input_storage {
648 Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
649 _ => candle_core::bail!("Expected CUDA storage"),
650 };
651 let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
652 unsafe {
653 ffi::launch_quant_fp8_blockwise_kernel_f32(
654 input_ptr as *const _,
655 weight_ptr as *mut _,
656 scale_ptr as *mut _,
657 weight_height,
658 weight_width,
659 weight_row_stride,
660 scale_stride,
661 weight_block_size_y,
662 weight_block_size_x,
663 dev.cuda_stream().cu_stream(),
664 )
665 };
666 }
667 DType::F16 => {
668 let input_storage = input.storage_and_layout().0;
669 let input_s = match &*input_storage {
670 Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
671 _ => candle_core::bail!("Expected CUDA storage"),
672 };
673 let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
674 unsafe {
675 ffi::launch_quant_fp8_blockwise_kernel_f16(
676 input_ptr as *const _,
677 weight_ptr as *mut _,
678 scale_ptr as *mut _,
679 weight_height,
680 weight_width,
681 weight_row_stride,
682 scale_stride,
683 weight_block_size_y,
684 weight_block_size_x,
685 dev.cuda_stream().cu_stream(),
686 )
687 };
688 }
689 DType::BF16 => {
690 let input_storage = input.storage_and_layout().0;
691 let input_s = match &*input_storage {
692 Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
693 _ => candle_core::bail!("Expected CUDA storage"),
694 };
695 let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
696 unsafe {
697 ffi::launch_quant_fp8_blockwise_kernel_bf16(
698 input_ptr as *const _,
699 weight_ptr as *mut _,
700 scale_ptr as *mut _,
701 weight_height,
702 weight_width,
703 weight_row_stride,
704 scale_stride,
705 weight_block_size_y,
706 weight_block_size_x,
707 dev.cuda_stream().cu_stream(),
708 )
709 };
710 }
711 other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
712 }
713
714 drop(_weight_guard);
716 drop(_scale_guard);
717
718 let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
720 let weight =
721 from_storage_no_op(Storage::Cuda(weight_storage), input.shape().clone(), false);
722
723 let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
725 let scale = from_storage_no_op(
726 Storage::Cuda(scale_storage),
727 candle_core::Shape::from_dims(&[grid_y, grid_x]),
728 false,
729 );
730
731 Ok((weight, scale))
732 }
733
734 #[cfg(not(feature = "cuda"))]
735 {
736 candle_core::bail!("FP8 blockwise quantization requires CUDA feature");
737 }
738}
739
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    /// Dequantizes a 5x5 all-ones FP8 weight with 2x2 blocks and scales `0..9` on `dev`.
    ///
    /// The weight is always created on the CPU and then moved to `dev`.
    fn dequant_ones(dev: &Device, out_ty: DType) -> Result<Tensor> {
        let weight = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?.to_device(dev)?;
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;
        ops::fp8_blockwise_dequantize(&weight, &inv_scales, vec![2, 2], out_ty)
    }

    /// Expected output of `dequant_ones`: every element equals the scale of its block.
    fn expected_f32() -> Vec<Vec<f32>> {
        vec![
            vec![0., 0., 1., 1., 2.],
            vec![0., 0., 1., 1., 2.],
            vec![3., 3., 4., 4., 5.],
            vec![3., 3., 4., 4., 5.],
            vec![6., 6., 7., 7., 8.],
        ]
    }

    /// Same expectation as `expected_f32`, converted element-wise to `bf16`.
    fn expected_bf16() -> Vec<Vec<bf16>> {
        expected_f32()
            .into_iter()
            .map(|row| row.into_iter().map(bf16::from_f32).collect())
            .collect()
    }

    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let res = dequant_ones(&Device::Cpu, DType::F32)?.to_vec2::<f32>()?;
        assert_eq!(res, expected_f32());

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = dequant_ones(&Device::Cpu, DType::F32)?.to_vec2::<f32>()?;
        let test = dequant_ones(&Device::new_cuda(0)?, DType::F32)?.to_vec2::<f32>()?;

        assert_eq!(test, truth);
        assert_eq!(test, expected_f32());

        Ok(())
    }

    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let res = dequant_ones(&Device::Cpu, DType::BF16)?.to_vec2::<bf16>()?;
        assert_eq!(res, expected_bf16());

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = dequant_ones(&Device::Cpu, DType::BF16)?.to_vec2::<bf16>()?;
        let test = dequant_ones(&Device::new_cuda(0)?, DType::BF16)?.to_vec2::<bf16>()?;

        assert_eq!(test, truth);
        assert_eq!(test, expected_bf16());

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        let input = Tensor::randn(0f32, 2f32, (8, 8), dev)?;
        let weight_block_size = vec![4, 4];

        let (quantized, scales) = ops::fp8_blockwise_quantize(&input, weight_block_size.clone())?;

        // 8x8 input with 4x4 blocks gives a 2x2 grid of scales.
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims2()?, (2, 2));

        let dequantized =
            ops::fp8_blockwise_dequantize(&quantized, &scales, weight_block_size, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec2::<f32>()?;
        let dequant_vec = dequantized.to_vec2::<f32>()?;

        // Largest absolute element-wise difference between input and round-tripped output.
        let max_error = input_vec
            .iter()
            .flatten()
            .zip(dequant_vec.iter().flatten())
            .map(|(val_in, val_out)| (val_in - val_out).abs())
            .fold(0f32, f32::max);

        assert!(max_error < 0.16, "Max error {} is too large", max_error);

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        // Fetch the reference FP8 weight and its scales from the test repository.
        let repo = Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        );
        let api = ApiBuilder::new()
            .with_progress(true)
            .build()
            .unwrap()
            .repo(repo);

        let filename = api.get("test_fp8.safetensors").unwrap();
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        // Dequantize, then run a plain linear layer over the dequantized weight.
        let weight_dq =
            ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;
        let truth = Linear::new(weight_dq, None).forward(&xs)?;

        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}