use candle_core::{CpuStorage, CustomOp1, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

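/// Blockwise FP8 dequantization as a candle custom op: each
/// `weight_block_size[0] x weight_block_size[1]` tile of the F8E4M3 weight
/// matrix is multiplied by the matching f32 entry of the scale grid.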
struct Fp8BlockwiseDequantize {
    weight_block_size: Vec<usize>,
    out_ty: DType,
}

impl Fp8BlockwiseDequantize {
    fn dispatch_dequant_blockwise<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);

        let res = vec![T::zero(); weight.len()];

        (0..grid_y).into_par_iter().for_each(|y| {
            (0..grid_x).into_par_iter().for_each(|x| {
                // SAFETY: every (y, x) block writes a disjoint set of output
                // positions, so these unsynchronized parallel writes never alias.
                let res_ptr = res.as_ptr() as *mut T;

                let scale = scale[y * scale_l.stride()[0] + x];

                let start_y = y * self.weight_block_size[0];
                let end_y = start_y + self.weight_block_size[0];

                let start_x = x * self.weight_block_size[1];
                let end_x = start_x + self.weight_block_size[1];

                for weight_y in start_y..end_y {
                    if weight_y >= weight_l.dims()[0] {
                        break;
                    }

                    let row_offset = weight_y * weight_l.stride()[0];
                    for weight_x in start_x..end_x {
                        if weight_x >= weight_l.dims()[1] {
                            break;
                        }

                        let weight_pos = row_offset + weight_x;

                        unsafe {
                            *res_ptr.wrapping_add(weight_pos) =
                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
                        }
                    }
                }
            });
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8BlockwiseDequantize {
    fn name(&self) -> &'static str {
        "fp8-blockwise-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(
                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
                ),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{blockwise_fp8::ffi, utils::slice_ptr};

        if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
            candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2");
        }

        let dev = weight_s.device();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

        let weight_height = weight_l.dim(0)? as i32;
        let weight_block_size_x = self.weight_block_size[0] as i32;
        let weight_width = weight_l.dim(1)? as i32;
        let weight_block_size_y = self.weight_block_size[1] as i32;
        let scale_stride = scale_l.stride()[0] as i32;
        let weight_row_stride = weight_l.stride()[0] as i32;

        let res = match self.out_ty {
            DType::F32 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<f32>(weight_l.shape().elem_count())?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            DType::F16 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<f16>(weight_l.shape().elem_count())?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            DType::BF16 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<bf16>(weight_l.shape().elem_count())?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        scale_s: &candle_core::MetalStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::MetalStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;

        if weight_l.start_offset() != 0
            || !weight_l.is_contiguous()
            || weight_s.dtype() != DType::F8E4M3
        {
            candle_core::bail!("Expected f8e4m3 weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() || scale_s.dtype() != DType::F32
        {
            candle_core::bail!("Expected f32 scales to have start offset 0 and be contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2");
        }

        let encoder = weight_s.device().command_encoder()?;
        encoder.set_label("dequant-blockwise-fp8");

        let device = weight_s.device();

        let out_shape = weight_l.shape().clone();

        // The output buffer must be sized for the *output* dtype, not the
        // 1-byte FP8 weight dtype.
        let output = device.new_buffer(
            out_shape.elem_count(),
            self.out_ty,
            "dequant-blockwise-fp8",
        )?;

        let weight_height = weight_l.dim(0)? as u32;
        let weight_block_size_x = self.weight_block_size[0] as u32;
        let weight_width = weight_l.dim(1)? as u32;
        let weight_block_size_y = self.weight_block_size[1] as u32;
        let scale_stride = scale_l.stride()[0] as u32;
        let weight_row_stride = weight_l.stride()[0] as u32;

        crate::metal_kernels::call_dequant_blockwise_fp8(
            device.device(),
            &encoder,
            &crate::metal_kernels::Kernels::new(),
            self.out_ty,
            weight_s.buffer(),
            scale_s.buffer(),
            &output,
            weight_height,
            weight_width,
            weight_row_stride,
            scale_stride,
            weight_block_size_y,
            weight_block_size_x,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            self.out_ty,
        );
        Ok((newstorage, out_shape))
    }
}

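/// Dequantize an FP8 (E4M3) weight matrix using blockwise f32 scales.
///
/// `weight` is a rank-2 F8E4M3 tensor and `inv_scales` is a rank-2 F32 tensor
/// holding one scale per `weight_block_size[0] x weight_block_size[1]` tile.
/// The result has `weight`'s shape and dtype `out_ty`.
///
/// A minimal sketch, mirroring the CPU test below:
///
/// ```ignore
/// let weight = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?;
/// let inv_scales = Tensor::arange(0f32, 9f32, &Device::Cpu)?.reshape((3, 3))?;
/// let dequant = fp8_blockwise_dequantize(&weight, &inv_scales, vec![2, 2], DType::F32)?;
/// assert_eq!(dequant.dims2()?, (5, 5));
/// ```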
pub fn fp8_blockwise_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    weight_block_size: Vec<usize>,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(
        weight,
        &Fp8BlockwiseDequantize {
            weight_block_size,
            out_ty,
        },
    )
}

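/// Blockwise FP8 quantization op: within each tile the largest absolute value
/// is mapped onto the E4M3 maximum finite value (448), producing one f32 scale
/// per tile.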
#[allow(dead_code)]
struct Fp8BlockwiseQuantize {
    weight_block_size: Vec<usize>,
}

impl Fp8BlockwiseQuantize {
    #[allow(dead_code)]
    fn dispatch_quant_blockwise<T: WithDType>(
        &self,
        input: &[T],
        input_l: &candle_core::Layout,
    ) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);

        let weight = vec![F8E4M3::from_f32(0.0); input.len()];
        let scale = vec![0f32; grid_y * grid_x];

        (0..grid_y).into_par_iter().for_each(|y| {
            (0..grid_x).into_par_iter().for_each(|x| {
                // SAFETY: each (y, x) block writes disjoint weight positions and
                // its own scale slot, so the parallel writes never alias.
                let weight_ptr = weight.as_ptr() as *mut F8E4M3;
                let scale_ptr = scale.as_ptr() as *mut f32;

                let start_y = y * self.weight_block_size[0];
                let end_y = start_y + self.weight_block_size[0];

                let start_x = x * self.weight_block_size[1];
                let end_x = start_x + self.weight_block_size[1];

                // First pass: find the maximum absolute value in the block.
                let mut max_abs = 0f32;
                for weight_y in start_y..end_y {
                    if weight_y >= input_l.dims()[0] {
                        break;
                    }

                    let row_offset = weight_y * input_l.stride()[0];
                    for weight_x in start_x..end_x {
                        if weight_x >= input_l.dims()[1] {
                            break;
                        }

                        let pos = row_offset + weight_x;
                        let val = input[pos].to_f64() as f32;
                        let abs_val = val.abs();
                        if abs_val > max_abs {
                            max_abs = abs_val;
                        }
                    }
                }

                // Map the block maximum onto 448, the largest finite E4M3 value.
                let block_scale = if max_abs > 0.0 {
                    max_abs / 448.0
                } else {
                    1e-12
                };

                unsafe {
                    *scale_ptr.wrapping_add(y * grid_x + x) = block_scale;
                }

                // Second pass: scale, clamp to the E4M3 range, and cast to FP8.
                for weight_y in start_y..end_y {
                    if weight_y >= input_l.dims()[0] {
                        break;
                    }

                    let row_offset = weight_y * input_l.stride()[0];
                    for weight_x in start_x..end_x {
                        if weight_x >= input_l.dims()[1] {
                            break;
                        }

                        let pos = row_offset + weight_x;
                        let val = input[pos].to_f64() as f32;
                        let scaled_val = (val / block_scale).clamp(-448.0, 448.0);

                        unsafe {
                            *weight_ptr.wrapping_add(pos) = F8E4M3::from_f32(scaled_val);
                        }
                    }
                }
            });
        });

        Ok((weight, scale))
    }
}

impl CustomOp1 for Fp8BlockwiseQuantize {
    fn name(&self) -> &'static str {
        "fp8-blockwise-quantize"
    }

    fn cpu_fwd(
        &self,
        input_s: &candle_core::CpuStorage,
        input_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
            candle_core::bail!("Expected input to have start offset 0 and be contiguous");
        }
        if input_l.dims().len() != 2 {
            candle_core::bail!("Expected input to be rank 2");
        }
        if self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected weight_block_size to have length 2");
        }

        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);

        let (weight, scale) = match input_s {
            CpuStorage::F32(input) => self.dispatch_quant_blockwise(input, input_l)?,
            CpuStorage::F16(input) => self.dispatch_quant_blockwise(input, input_l)?,
            CpuStorage::BF16(input) => self.dispatch_quant_blockwise(input, input_l)?,
            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
        };

        // A CustomOp1 can only return a single storage, so pack the quantized
        // weight followed by the scales (lossily cast to F8E4M3) into one buffer.
        let mut packed = Vec::with_capacity(weight.len() + scale.len());
        packed.extend_from_slice(&weight);

        for &s in &scale {
            packed.push(F8E4M3::from_f32(s));
        }

        Ok((
            CpuStorage::F8E4M3(packed),
            candle_core::Shape::from_dims(&[
                input_l.dims()[0] + grid_y,
                input_l.dims()[1].max(grid_x),
            ]),
        ))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        input_s: &candle_core::CudaStorage,
        input_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{blockwise_fp8::ffi, utils::slice_ptr};

        if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
            candle_core::bail!("Do not have blockwise FP8 quant kernels.");
        }

        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
            candle_core::bail!("Expected input to have start offset 0 and be contiguous");
        }
        if input_l.dims().len() != 2 {
            candle_core::bail!("Expected input to be rank 2");
        }
        if self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected weight_block_size to have length 2");
        }

        let dev = input_s.device();

        let weight_height = input_l.dim(0)? as i32;
        let weight_block_size_y = self.weight_block_size[0] as i32;
        let weight_width = input_l.dim(1)? as i32;
        let weight_block_size_x = self.weight_block_size[1] as i32;
        let weight_row_stride = input_l.stride()[0] as i32;

        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
        let scale_stride = grid_x as i32;

        let weight_output = dev.alloc_zeros::<F8E4M3>(input_l.shape().elem_count())?;
        let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;

        let (weight_ptr, weight_guard) = slice_ptr(&weight_output, 0);
        let (scale_ptr, scale_guard) = slice_ptr(&scale_output, 0);

        match input_s.dtype() {
            DType::F32 => {
                let (input, _input_guard) =
                    slice_ptr(input_s.as_cuda_slice::<f32>()?, input_l.start_offset());
                unsafe {
                    ffi::launch_quant_fp8_blockwise_kernel_f32(
                        input as *const _,
                        weight_ptr as *mut _,
                        scale_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }
            DType::F16 => {
                let (input, _input_guard) =
                    slice_ptr(input_s.as_cuda_slice::<f16>()?, input_l.start_offset());
                unsafe {
                    ffi::launch_quant_fp8_blockwise_kernel_f16(
                        input as *const _,
                        weight_ptr as *mut _,
                        scale_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }
            DType::BF16 => {
                let (input, _input_guard) =
                    slice_ptr(input_s.as_cuda_slice::<bf16>()?, input_l.start_offset());
                unsafe {
                    ffi::launch_quant_fp8_blockwise_kernel_bf16(
                        input as *const _,
                        weight_ptr as *mut _,
                        scale_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }
            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
        }

        drop(weight_guard);
        drop(scale_guard);

        // Note: this op returns only the quantized weight; the scales are
        // discarded here. Use `fp8_blockwise_quantize` to obtain both tensors.
        let res = CudaStorage::wrap_cuda_slice(weight_output, input_s.device().clone());
        Ok((res, input_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        _input_s: &candle_core::MetalStorage,
        _input_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        candle_core::bail!("FP8 blockwise quantization not yet implemented for Metal");
    }
}

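/// Quantize a rank-2 tensor to FP8 (E4M3) with one f32 scale per
/// `weight_block_size[0] x weight_block_size[1]` block, returning
/// `(quantized_weight, scales)`. CUDA-only for now.
///
/// A minimal sketch, mirroring the round-trip test below:
///
/// ```ignore
/// let input = Tensor::randn(0f32, 2f32, (8, 8), &cuda_dev)?;
/// let (quant, scales) = fp8_blockwise_quantize(&input, vec![4, 4])?;
/// assert_eq!(quant.shape(), input.shape());
/// assert_eq!(scales.dims2()?, (2, 2));
/// ```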
pub fn fp8_blockwise_quantize(
    #[allow(unused_variables)] input: &Tensor,
    #[allow(unused_variables)] weight_block_size: Vec<usize>,
) -> Result<(Tensor, Tensor)> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{blockwise_fp8::ffi, utils::slice_ptr};

        if !matches!(input.device(), Device::Cuda(_)) {
            candle_core::bail!("FP8 blockwise quantization only supported on CUDA for now");
        }

        if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
            candle_core::bail!("Do not have blockwise FP8 quant kernels.");
        }

        let input_l = input.layout();
        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
            candle_core::bail!("Expected input to have start offset 0 and be contiguous");
        }
        if input.dims().len() != 2 {
            candle_core::bail!("Expected input to be rank 2");
        }
        if weight_block_size.len() != 2 {
            candle_core::bail!("Expected weight_block_size to have length 2");
        }

        let dev = match input.device() {
            Device::Cuda(dev) => dev,
            _ => unreachable!(),
        };

        let weight_height = input.dim(0)? as i32;
        let weight_block_size_y = weight_block_size[0] as i32;
        let weight_width = input.dim(1)? as i32;
        let weight_block_size_x = weight_block_size[1] as i32;
        let weight_row_stride = input_l.stride()[0] as i32;

        let grid_y = input.dim(0)?.div_ceil(weight_block_size[0]);
        let grid_x = input.dim(1)?.div_ceil(weight_block_size[1]);
        let scale_stride = grid_x as i32;

        let weight_output = dev.alloc_zeros::<F8E4M3>(input.shape().elem_count())?;
        let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;

        let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
        let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

        match input.dtype() {
            DType::F32 => {
                let input_storage = input.storage_and_layout().0;
                let input_s = match &*input_storage {
                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                    _ => candle_core::bail!("Expected CUDA storage"),
                };
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                unsafe {
                    ffi::launch_quant_fp8_blockwise_kernel_f32(
                        input_ptr as *const _,
                        weight_ptr as *mut _,
                        scale_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }
            DType::F16 => {
                let input_storage = input.storage_and_layout().0;
                let input_s = match &*input_storage {
                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                    _ => candle_core::bail!("Expected CUDA storage"),
                };
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                unsafe {
                    ffi::launch_quant_fp8_blockwise_kernel_f16(
                        input_ptr as *const _,
                        weight_ptr as *mut _,
                        scale_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }
            DType::BF16 => {
                let input_storage = input.storage_and_layout().0;
                let input_s = match &*input_storage {
                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                    _ => candle_core::bail!("Expected CUDA storage"),
                };
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                unsafe {
                    ffi::launch_quant_fp8_blockwise_kernel_bf16(
                        input_ptr as *const _,
                        weight_ptr as *mut _,
                        scale_ptr as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }
            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
        }

        drop(_weight_guard);
        drop(_scale_guard);

        let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
        let weight = Tensor::from((Storage::Cuda(weight_storage), input.shape().clone()));

        let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
        let scale = Tensor::from((
            Storage::Cuda(scale_storage),
            candle_core::Shape::from_dims(&[grid_y, grid_x]),
        ));

        Ok((weight, scale))
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 blockwise quantization requires CUDA feature");
    }
}

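/// Matrix multiply a dense activation against a blockwise-FP8 weight without
/// materializing the dequantized weight: `output = input @ weight^T`, where
/// `input` is `[m, k]`, `weight` is `[n, k]` in F8E4M3, and `scales` holds one
/// f32 per `weight_block_size` tile. Dequantization happens inside the kernel.
///
/// A minimal sketch, assuming DeepSeek-style 128x128 blocks:
///
/// ```ignore
/// let out = fp8_blockwise_matmul(&xs, &weight, &scales, &[128, 128])?;
/// assert_eq!(out.dims2()?, (xs.dim(0)?, weight.dim(0)?));
/// ```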
#[cfg(feature = "cuda")]
pub fn fp8_blockwise_matmul(
    input: &Tensor,
    weight: &Tensor,
    scales: &Tensor,
    weight_block_size: &[usize],
) -> Result<Tensor> {
    use candle_core::{CudaStorage, Device, Storage};
    use half::{bf16, f16};

    use crate::{blockwise_fp8::ffi, utils::slice_ptr};

    if !ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
        candle_core::bail!("Do not have blockwise FP8 GEMM kernels.");
    }

    if !matches!(input.device(), Device::Cuda(_)) {
        candle_core::bail!("FP8 blockwise matmul only supported on CUDA");
    }

    let input = input.contiguous()?;
    let weight = weight.contiguous()?;
    let scales = scales.contiguous()?;

    if input.dims().len() != 2 {
        candle_core::bail!("Expected input to be rank 2, got {:?}", input.dims());
    }
    if weight.dims().len() != 2 {
        candle_core::bail!("Expected weight to be rank 2, got {:?}", weight.dims());
    }
    if weight.dtype() != DType::F8E4M3 {
        candle_core::bail!("Expected FP8 weight, got {:?}", weight.dtype());
    }

    let m = input.dim(0)? as i32;
    let k = input.dim(1)? as i32;
    let n = weight.dim(0)? as i32;

    if weight.dim(1)? as i32 != k {
        candle_core::bail!(
            "Weight K dimension {} doesn't match input K dimension {}",
            weight.dim(1)?,
            k
        );
    }

    let dev = match input.device() {
        Device::Cuda(dev) => dev,
        _ => unreachable!(),
    };

    let block_size_y = weight_block_size[0] as i32;
    let block_size_x = weight_block_size[1] as i32;
    let scale_row_stride = scales.dim(1)? as i32;

    let input_l = input.layout();
    let weight_l = weight.layout();
    let scales_l = scales.layout();

    let input_storage = input.storage_and_layout().0;
    let weight_storage = weight.storage_and_layout().0;
    let scales_storage = scales.storage_and_layout().0;

    let weight_s = match &*weight_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<F8E4M3>()?,
        _ => candle_core::bail!("Expected CUDA storage for weight"),
    };
    let scales_s = match &*scales_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
        _ => candle_core::bail!("Expected CUDA storage for scales"),
    };

    let (weight_ptr, _weight_guard) = slice_ptr(weight_s, weight_l.start_offset());
    let (scales_ptr, _scales_guard) = slice_ptr(scales_s, scales_l.start_offset());

    match input.dtype() {
        DType::F16 => {
            let output = dev.alloc_zeros::<f16>((m * n) as usize)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                unsafe {
                    ffi::launch_fp8_matmul_f16(
                        input_ptr as *const _,
                        weight_ptr as *const _,
                        scales_ptr as *const _,
                        output_ptr as *mut _,
                        m,
                        n,
                        k,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[m as usize, n as usize]),
            )))
        }
        DType::BF16 => {
            let output = dev.alloc_zeros::<bf16>((m * n) as usize)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                unsafe {
                    ffi::launch_fp8_matmul_bf16(
                        input_ptr as *const _,
                        weight_ptr as *const _,
                        scales_ptr as *const _,
                        output_ptr as *mut _,
                        m,
                        n,
                        k,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[m as usize, n as usize]),
            )))
        }
        other => candle_core::bail!("Unsupported input dtype for FP8 matmul: {:?}", other),
    }
}

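/// Indexed MoE GEMM over blockwise-FP8 expert weights: for each token and each
/// of its `topk` routed experts, computes `input_row @ weights[expert]^T`.
///
/// Shapes: `input` is `[num_tokens, k]` (or `[num_tokens, topk, k]` when each
/// routed expert has its own activation row), `weights` is
/// `[num_experts, n, k]` in F8E4M3, `indices` is `[num_tokens, topk]` (u32),
/// and the output is `[num_tokens, topk, n]`.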
#[cfg(feature = "cuda")]
pub fn fp8_indexed_moe_gemm(
    input: &Tensor,
    weights: &Tensor,
    scales: &Tensor,
    indices: &Tensor,
    weight_block_size: &[usize],
) -> Result<Tensor> {
    use candle_core::{CudaStorage, Device, Storage};
    use half::{bf16, f16};

    use crate::{blockwise_fp8::ffi, utils::slice_ptr};

    if !ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
        candle_core::bail!("Do not have blockwise FP8 GEMM kernels.");
    }

    if !matches!(input.device(), Device::Cuda(_)) {
        candle_core::bail!("FP8 indexed MoE GEMM only supported on CUDA");
    }

    let input = input.contiguous()?;
    let weights = weights.contiguous()?;
    let scales = scales.contiguous()?;
    let indices = indices.contiguous()?;

    // The input is either [num_tokens, topk, k] (a separate activation row per
    // routed expert) or [num_tokens, k] (one shared row per token).
    let (num_tokens, input_has_topk_dim, k) = if input.dims().len() == 3 {
        let dims = input.dims3()?;
        (dims.0, dims.1 > 1, dims.2)
    } else if input.dims().len() == 2 {
        let dims = input.dims2()?;
        (dims.0, false, dims.1)
    } else {
        candle_core::bail!("Expected input to be rank 2 or 3, got {:?}", input.dims());
    };

    let (indices_tokens, topk) = indices.dims2()?;
    if indices_tokens != num_tokens {
        candle_core::bail!(
            "Indices num_tokens {} doesn't match input num_tokens {}",
            indices_tokens,
            num_tokens
        );
    }

    if weights.dims().len() != 3 {
        candle_core::bail!("Expected weights to be rank 3, got {:?}", weights.dims());
    }
    let (num_experts, n, weight_k) = weights.dims3()?;
    if weight_k != k {
        candle_core::bail!(
            "Weights K dimension {} doesn't match input K dimension {}",
            weight_k,
            k
        );
    }

    if weights.dtype() != DType::F8E4M3 {
        candle_core::bail!("Expected FP8 weights, got {:?}", weights.dtype());
    }

    let dev = match input.device() {
        Device::Cuda(dev) => dev,
        _ => unreachable!(),
    };

    let block_size_y = weight_block_size[0] as i32;
    let block_size_x = weight_block_size[1] as i32;

    let scale_row_stride = scales.dim(2)? as i32;

    let input_l = input.layout();
    let weights_l = weights.layout();
    let scales_l = scales.layout();
    let indices_l = indices.layout();

    let input_storage = input.storage_and_layout().0;
    let weights_storage = weights.storage_and_layout().0;
    let scales_storage = scales.storage_and_layout().0;
    let indices_storage = indices.storage_and_layout().0;

    let weights_s = match &*weights_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<F8E4M3>()?,
        _ => candle_core::bail!("Expected CUDA storage for weights"),
    };
    let scales_s = match &*scales_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
        _ => candle_core::bail!("Expected CUDA storage for scales"),
    };
    let indices_s = match &*indices_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<u32>()?,
        _ => candle_core::bail!("Expected CUDA storage for indices"),
    };

    let (weights_ptr, _weights_guard) = slice_ptr(weights_s, weights_l.start_offset());
    let (scales_ptr, _scales_guard) = slice_ptr(scales_s, scales_l.start_offset());
    let (indices_ptr, _indices_guard) = slice_ptr(indices_s, indices_l.start_offset());

    match input.dtype() {
        DType::F16 => {
            let output = dev.alloc_zeros::<f16>(num_tokens * topk * n)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                unsafe {
                    ffi::launch_fp8_indexed_moe_gemm_f16(
                        input_ptr as *const _,
                        weights_ptr as *const _,
                        scales_ptr as *const _,
                        indices_ptr as *const _,
                        output_ptr as *mut _,
                        num_tokens as i32,
                        topk as i32,
                        num_experts as i32,
                        n as i32,
                        k as i32,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        input_has_topk_dim,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[num_tokens, topk, n]),
            )))
        }
        DType::BF16 => {
            let output = dev.alloc_zeros::<bf16>(num_tokens * topk * n)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                unsafe {
                    ffi::launch_fp8_indexed_moe_gemm_bf16(
                        input_ptr as *const _,
                        weights_ptr as *const _,
                        scales_ptr as *const _,
                        indices_ptr as *const _,
                        output_ptr as *mut _,
                        num_tokens as i32,
                        topk as i32,
                        num_experts as i32,
                        n as i32,
                        k as i32,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        input_has_topk_dim,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[num_tokens, topk, n]),
            )))
        }
        other => candle_core::bail!(
            "Unsupported input dtype for FP8 indexed MoE GEMM: {:?}",
            other
        ),
    }
}

#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

        let res = dequant.to_vec2::<f32>()?;
        assert_eq!(
            res,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            // Create the FP8 weight on the CPU, then move it to the CUDA device.
            let weight_cpu = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?;
            let weight = weight_cpu.to_device(dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::BF16)?;

        let res = dequant.to_vec2::<bf16>()?;
        assert_eq!(
            res,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            // Create the FP8 weight on the CPU, then move it to the CUDA device.
            let weight_cpu = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?;
            let weight = weight_cpu.to_device(dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        let input = Tensor::randn(0f32, 2f32, (8, 8), dev)?;
        let weight_block_size = vec![4, 4];

        let (quantized, scales) = ops::fp8_blockwise_quantize(&input, weight_block_size.clone())?;

        assert_eq!(quantized.shape(), input.shape());
        // An 8x8 input quantized with 4x4 blocks yields a 2x2 scale grid.
        assert_eq!(scales.dims2()?, (2, 2));

        let dequantized =
            ops::fp8_blockwise_dequantize(&quantized, &scales, weight_block_size, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec2::<f32>()?;
        let dequant_vec = dequantized.to_vec2::<f32>()?;

        let mut max_error = 0f32;
        for (row_in, row_out) in input_vec.iter().zip(dequant_vec.iter()) {
            for (val_in, val_out) in row_in.iter().zip(row_out.iter()) {
                let error = (val_in - val_out).abs();
                max_error = max_error.max(error);
            }
        }

        // E4M3 keeps only a 3-bit mantissa, so allow a fairly loose round-trip tolerance.
        assert!(max_error < 0.16, "Max error {} is too large", max_error);

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let api = ApiBuilder::new().with_progress(true).build().unwrap();
        let api = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let filename = api.get("test_fp8.safetensors").unwrap();
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        // Reference path: dequantize the full weight, then run a plain bf16 linear.
        let truth = {
            let weight_dq =
                ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

            let lin_dq = Linear::new(weight_dq, None);
            lin_dq.forward(&xs)?
        };

        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}