mistralrs_quant/blockwise_fp8/
ops.rs

1#[cfg(feature = "cuda")]
2use candle_core::from_storage_no_op;
3use candle_core::{CpuStorage, CustomOp1, CustomOp2, DType, Result, Tensor, WithDType};
4use float8::F8E4M3;
5use rayon::iter::{IntoParallelIterator, ParallelIterator};
6
7struct Fp8BlockwiseDequantize {
8    weight_block_size: Vec<usize>,
9    out_ty: DType,
10}
11
12impl Fp8BlockwiseDequantize {
13    fn dispatch_dequant_blockwise<T: WithDType>(
14        &self,
15        weight: &[F8E4M3],
16        scale: &[f32],
17        weight_l: &candle_core::Layout,
18        scale_l: &candle_core::Layout,
19    ) -> candle_core::Result<Vec<T>> {
20        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
21        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);
22
23        let res = vec![T::zero(); weight.len()];
24
25        (0..grid_y).into_par_iter().for_each(|y| {
26            (0..grid_x).into_par_iter().for_each(|x| {
27                let res_ptr = res.as_ptr() as *mut T;
28
29                let scale = scale[y * scale_l.stride()[0] + x];
30
31                let start_y = y * self.weight_block_size[0];
32                let end_y = start_y + self.weight_block_size[0];
33
34                let start_x = x * self.weight_block_size[1];
35                let end_x = start_x + self.weight_block_size[1];
36
37                for weight_y in start_y..end_y {
38                    if weight_y >= weight_l.dims()[0] {
39                        break;
40                    }
41
42                    let row_offset = weight_y * weight_l.stride()[0];
43                    for weight_x in start_x..end_x {
44                        if weight_x >= weight_l.dims()[1] {
45                            break;
46                        }
47
48                        let weight_pos = row_offset + weight_x;
49
50                        // SAFETY: We know each thread will only update indepedant values!
51                        unsafe {
52                            *res_ptr.wrapping_add(weight_pos) =
53                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
54                        }
55                    }
56                }
57            });
58        });
59
60        Ok(res)
61    }
62}
63
64impl CustomOp2 for Fp8BlockwiseDequantize {
65    fn name(&self) -> &'static str {
66        "fp8-blockwise-dequantize"
67    }
68
69    fn cpu_fwd(
70        &self,
71        scale_s: &candle_core::CpuStorage,
72        scale_l: &candle_core::Layout,
73        weight_s: &candle_core::CpuStorage,
74        weight_l: &candle_core::Layout,
75    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
76        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
77            candle_core::bail!("Expected F8E4M3 weight!");
78        };
79        let candle_core::CpuStorage::F32(scale) = scale_s else {
80            candle_core::bail!("Expected F8E4M3 weight!");
81        };
82        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
83            candle_core::bail!("Expected weight to have start offset 0, continuous");
84        }
85        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
86            candle_core::bail!("Expected scales to have start offset 0, continuous");
87        }
88        if weight_l.dims().len() != 2 {
89            candle_core::bail!("Expected weight to be rank 2");
90        }
91        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
92            candle_core::bail!("Expected scale to be rank 2");
93        }
94
95        match self.out_ty {
96            DType::F32 => Ok((
97                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
98                weight_l.shape().clone(),
99            )),
100            DType::BF16 => Ok((
101                CpuStorage::BF16(
102                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
103                ),
104                weight_l.shape().clone(),
105            )),
106            DType::F16 => Ok((
107                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
108                weight_l.shape().clone(),
109            )),
110            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
111        }
112    }
113
114    #[cfg(feature = "cuda")]
115    fn cuda_fwd(
116        &self,
117        scale_s: &candle_core::CudaStorage,
118        scale_l: &candle_core::Layout,
119        weight_s: &candle_core::CudaStorage,
120        weight_l: &candle_core::Layout,
121    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
122        use candle_core::{backend::BackendStorage, CudaStorage};
123        use half::{bf16, f16};
124
125        use crate::{blockwise_fp8::ffi, utils::slice_ptr};
126
127        if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
128            candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
129        }
130
131        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
132            candle_core::bail!("Expected weight to have start offset 0, continuous");
133        }
134        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
135            candle_core::bail!("Expected scales to have start offset 0, continuous");
136        }
137        if weight_l.dims().len() != 2 {
138            candle_core::bail!("Expected weight to be rank 2");
139        }
140        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
141            candle_core::bail!("Expected scale to be rank 2");
142        }
143
144        let dev = weight_s.device();
145
146        let (weight, _weight_guard) =
147            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
148        let (scale, _scale_guard) =
149            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());
150
151        let weight_height = weight_l.dim(0)? as i32;
152        let weight_block_size_x = self.weight_block_size[0] as i32;
153        let weight_width = weight_l.dim(1)? as i32;
154        let weight_block_size_y = self.weight_block_size[1] as i32;
155        let scale_stride = scale_l.stride()[0] as i32;
156        let weight_row_stride = weight_l.stride()[0] as i32;
157
158        let res = match self.out_ty {
159            DType::F32 => {
160                let output = weight_s
161                    .device()
162                    .alloc_zeros::<f32>(weight_l.shape().elem_count())?;
163                let (output_ptr, output_guard) = slice_ptr(&output, 0);
164                unsafe {
165                    ffi::launch_dequant_fp8_blockwise_kernel_f32(
166                        weight as *const _,
167                        scale as *const _,
168                        output_ptr as *mut _,
169                        weight_height,
170                        weight_width,
171                        weight_row_stride,
172                        scale_stride,
173                        weight_block_size_y,
174                        weight_block_size_x,
175                        dev.cuda_stream().cu_stream(),
176                    )
177                };
178                drop(output_guard);
179                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
180            }
181            DType::F16 => {
182                let output = weight_s
183                    .device()
184                    .alloc_zeros::<f16>(weight_l.shape().elem_count())?;
185                let (output_ptr, output_guard) = slice_ptr(&output, 0);
186                unsafe {
187                    ffi::launch_dequant_fp8_blockwise_kernel_f16(
188                        weight as *const _,
189                        scale as *const _,
190                        output_ptr as *mut _,
191                        weight_height,
192                        weight_width,
193                        weight_row_stride,
194                        scale_stride,
195                        weight_block_size_y,
196                        weight_block_size_x,
197                        dev.cuda_stream().cu_stream(),
198                    )
199                };
200                drop(output_guard);
201                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
202            }
203            DType::BF16 => {
204                let output = weight_s
205                    .device()
206                    .alloc_zeros::<bf16>(weight_l.shape().elem_count())?;
207                let (output_ptr, output_guard) = slice_ptr(&output, 0);
208                unsafe {
209                    ffi::launch_dequant_fp8_blockwise_kernel_bf16(
210                        weight as *const _,
211                        scale as *const _,
212                        output_ptr as *mut _,
213                        weight_height,
214                        weight_width,
215                        weight_row_stride,
216                        scale_stride,
217                        weight_block_size_y,
218                        weight_block_size_x,
219                        dev.cuda_stream().cu_stream(),
220                    )
221                };
222                drop(output_guard);
223                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
224            }
225            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
226        };
227
228        Ok((res, weight_l.shape().clone()))
229    }
230
231    #[cfg(feature = "metal")]
232    fn metal_fwd(
233        &self,
234        scale_s: &candle_core::MetalStorage,
235        scale_l: &candle_core::Layout,
236        weight_s: &candle_core::MetalStorage,
237        weight_l: &candle_core::Layout,
238    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
239        use candle_core::backend::BackendStorage;
240
241        if weight_l.start_offset() != 0
242            || !weight_l.is_contiguous()
243            || weight_s.dtype() != DType::F8E4M3
244        {
245            candle_core::bail!("Expected f8e4m3 weight to have start offset 0, continuous");
246        }
247        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() || scale_s.dtype() != DType::F32
248        {
249            candle_core::bail!("Expected f32 scales to have start offset 0, continuous");
250        }
251        if weight_l.dims().len() != 2 {
252            candle_core::bail!("Expected weight to be rank 2");
253        }
254        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
255            candle_core::bail!("Expected scale to be rank 2");
256        }
257
258        let command_buffer = weight_s.device().command_buffer()?;
259        command_buffer.set_label("dequant-blockwise-fp8");
260
261        let device = weight_s.device();
262
263        let out_shape = weight_l.shape().clone();
264
265        let output = device.new_buffer(
266            out_shape.elem_count(),
267            weight_s.dtype(),
268            "dequant-blockwise-fp8",
269        )?;
270
271        let weight_height = weight_l.dim(0)? as u32;
272        let weight_block_size_x = self.weight_block_size[0] as u32;
273        let weight_width = weight_l.dim(1)? as u32;
274        let weight_block_size_y = self.weight_block_size[1] as u32;
275        let scale_stride = scale_l.stride()[0] as u32;
276        let weight_row_stride = weight_l.stride()[0] as u32;
277
278        crate::metal_kernels::call_dequant_blockwise_fp8(
279            device.device(),
280            &command_buffer,
281            &crate::metal_kernels::Kernels::new(),
282            self.out_ty,
283            weight_s.buffer(),
284            scale_s.buffer(),
285            &output,
286            weight_height,
287            weight_width,
288            weight_row_stride,
289            scale_stride,
290            weight_block_size_y,
291            weight_block_size_x,
292        )
293        .map_err(candle_core::Error::wrap)?;
294
295        let newstorage = candle_core::MetalStorage::new(
296            output,
297            device.clone(),
298            out_shape.elem_count(),
299            self.out_ty,
300        );
301        Ok((newstorage, out_shape))
302    }
303}
304
305/// FP8 blockwise dequantize.
306/// - Expects weight to be fp8
307/// - Expects inv_scales to be f32
308/// - weight * inv_scale = dequantized
309pub fn fp8_blockwise_dequantize(
310    weight: &Tensor,
311    inv_scales: &Tensor,
312    weight_block_size: Vec<usize>,
313    out_ty: DType,
314) -> Result<Tensor> {
315    inv_scales.apply_op2_no_bwd(
316        weight,
317        &Fp8BlockwiseDequantize {
318            weight_block_size,
319            out_ty,
320        },
321    )
322}
323
/// Custom op that quantizes a rank-2 float tensor to FP8 (E4M3) with one scale
/// per block.
// Nothing in this file constructs the op, hence the lint allowance.
#[allow(dead_code)]
struct Fp8BlockwiseQuantize {
    /// Block extents, two entries: entry 0 is used along dimension 0 (rows) and
    /// entry 1 along dimension 1 (columns) of the input.
    weight_block_size: Vec<usize>,
}
328
329impl Fp8BlockwiseQuantize {
330    #[allow(dead_code)]
331    fn dispatch_quant_blockwise<T: WithDType>(
332        &self,
333        input: &[T],
334        input_l: &candle_core::Layout,
335    ) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
336        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
337        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
338
339        let weight = vec![F8E4M3::from_f32(0.0); input.len()];
340        let scale = vec![0f32; grid_y * grid_x];
341
342        (0..grid_y).into_par_iter().for_each(|y| {
343            (0..grid_x).into_par_iter().for_each(|x| {
344                let weight_ptr = weight.as_ptr() as *mut F8E4M3;
345                let scale_ptr = scale.as_ptr() as *mut f32;
346
347                let start_y = y * self.weight_block_size[0];
348                let end_y = start_y + self.weight_block_size[0];
349
350                let start_x = x * self.weight_block_size[1];
351                let end_x = start_x + self.weight_block_size[1];
352
353                // Find max absolute value in block
354                let mut max_abs = 0f32;
355                for weight_y in start_y..end_y {
356                    if weight_y >= input_l.dims()[0] {
357                        break;
358                    }
359
360                    let row_offset = weight_y * input_l.stride()[0];
361                    for weight_x in start_x..end_x {
362                        if weight_x >= input_l.dims()[1] {
363                            break;
364                        }
365
366                        let pos = row_offset + weight_x;
367                        let val = input[pos].to_f64() as f32;
368                        let abs_val = val.abs();
369                        if abs_val > max_abs {
370                            max_abs = abs_val;
371                        }
372                    }
373                }
374
375                // Calculate scale
376                let block_scale = if max_abs > 0.0 {
377                    max_abs / 448.0
378                } else {
379                    1e-12
380                };
381
382                // SAFETY: We know each thread will only update independent values!
383                unsafe {
384                    *scale_ptr.wrapping_add(y * grid_x + x) = block_scale;
385                }
386
387                // Quantize values
388                for weight_y in start_y..end_y {
389                    if weight_y >= input_l.dims()[0] {
390                        break;
391                    }
392
393                    let row_offset = weight_y * input_l.stride()[0];
394                    for weight_x in start_x..end_x {
395                        if weight_x >= input_l.dims()[1] {
396                            break;
397                        }
398
399                        let pos = row_offset + weight_x;
400                        let val = input[pos].to_f64() as f32;
401                        let scaled_val = (val / block_scale).clamp(-448.0, 448.0);
402
403                        // SAFETY: We know each thread will only update independent values!
404                        unsafe {
405                            *weight_ptr.wrapping_add(pos) = F8E4M3::from_f32(scaled_val);
406                        }
407                    }
408                }
409            });
410        });
411
412        Ok((weight, scale))
413    }
414}
415
416impl CustomOp1 for Fp8BlockwiseQuantize {
417    fn name(&self) -> &'static str {
418        "fp8-blockwise-quantize"
419    }
420
421    fn cpu_fwd(
422        &self,
423        input_s: &candle_core::CpuStorage,
424        input_l: &candle_core::Layout,
425    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
426        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
427            candle_core::bail!("Expected input to have start offset 0, continuous");
428        }
429        if input_l.dims().len() != 2 {
430            candle_core::bail!("Expected input to be rank 2");
431        }
432        if self.weight_block_size.len() != 2 {
433            candle_core::bail!("Expected weight_block_size to have length 2");
434        }
435
436        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
437        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
438
439        let (weight, scale) = match input_s {
440            CpuStorage::F32(input) => self.dispatch_quant_blockwise(input, input_l)?,
441            CpuStorage::F16(input) => self.dispatch_quant_blockwise(input, input_l)?,
442            CpuStorage::BF16(input) => self.dispatch_quant_blockwise(input, input_l)?,
443            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
444        };
445
446        // Return both weight and scale tensors packed into a single storage
447        // We'll need to unpack them after the op
448        let mut packed = Vec::with_capacity(weight.len() + scale.len());
449        packed.extend_from_slice(&weight);
450
451        // Convert scale to F8E4M3 for storage (will convert back when unpacking)
452        for &s in &scale {
453            packed.push(F8E4M3::from_f32(s));
454        }
455
456        Ok((
457            CpuStorage::F8E4M3(packed),
458            candle_core::Shape::from_dims(&[
459                input_l.dims()[0] + grid_y,
460                input_l.dims()[1].max(grid_x),
461            ]),
462        ))
463    }
464
465    #[cfg(feature = "cuda")]
466    fn cuda_fwd(
467        &self,
468        input_s: &candle_core::CudaStorage,
469        input_l: &candle_core::Layout,
470    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
471        use candle_core::{backend::BackendStorage, CudaStorage};
472        use half::{bf16, f16};
473
474        use crate::{blockwise_fp8::ffi, utils::slice_ptr};
475
476        if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
477            candle_core::bail!("Do not have blockwise FP8 quant kernels.");
478        }
479
480        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
481            candle_core::bail!("Expected input to have start offset 0, continuous");
482        }
483        if input_l.dims().len() != 2 {
484            candle_core::bail!("Expected input to be rank 2");
485        }
486        if self.weight_block_size.len() != 2 {
487            candle_core::bail!("Expected weight_block_size to have length 2");
488        }
489
490        let dev = input_s.device();
491
492        let weight_height = input_l.dim(0)? as i32;
493        let weight_block_size_y = self.weight_block_size[0] as i32;
494        let weight_width = input_l.dim(1)? as i32;
495        let weight_block_size_x = self.weight_block_size[1] as i32;
496        let weight_row_stride = input_l.stride()[0] as i32;
497
498        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
499        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
500        let scale_stride = grid_x as i32;
501
502        // Allocate output buffers
503        let weight_output = dev.alloc_zeros::<F8E4M3>(input_l.shape().elem_count())?;
504        let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;
505
506        let (weight_ptr, weight_guard) = slice_ptr(&weight_output, 0);
507        let (scale_ptr, scale_guard) = slice_ptr(&scale_output, 0);
508
509        match input_s.dtype() {
510            DType::F32 => {
511                let (input, _input_guard) =
512                    slice_ptr(input_s.as_cuda_slice::<f32>()?, input_l.start_offset());
513                unsafe {
514                    ffi::launch_quant_fp8_blockwise_kernel_f32(
515                        input as *const _,
516                        weight_ptr as *mut _,
517                        scale_ptr as *mut _,
518                        weight_height,
519                        weight_width,
520                        weight_row_stride,
521                        scale_stride,
522                        weight_block_size_y,
523                        weight_block_size_x,
524                        dev.cuda_stream().cu_stream(),
525                    )
526                };
527            }
528            DType::F16 => {
529                let (input, _input_guard) =
530                    slice_ptr(input_s.as_cuda_slice::<f16>()?, input_l.start_offset());
531                unsafe {
532                    ffi::launch_quant_fp8_blockwise_kernel_f16(
533                        input as *const _,
534                        weight_ptr as *mut _,
535                        scale_ptr as *mut _,
536                        weight_height,
537                        weight_width,
538                        weight_row_stride,
539                        scale_stride,
540                        weight_block_size_y,
541                        weight_block_size_x,
542                        dev.cuda_stream().cu_stream(),
543                    )
544                };
545            }
546            DType::BF16 => {
547                let (input, _input_guard) =
548                    slice_ptr(input_s.as_cuda_slice::<bf16>()?, input_l.start_offset());
549                unsafe {
550                    ffi::launch_quant_fp8_blockwise_kernel_bf16(
551                        input as *const _,
552                        weight_ptr as *mut _,
553                        scale_ptr as *mut _,
554                        weight_height,
555                        weight_width,
556                        weight_row_stride,
557                        scale_stride,
558                        weight_block_size_y,
559                        weight_block_size_x,
560                        dev.cuda_stream().cu_stream(),
561                    )
562                };
563            }
564            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
565        }
566
567        drop(weight_guard);
568        drop(scale_guard);
569
570        // Return just the weight tensor - we'll handle scale separately
571        let res = CudaStorage::wrap_cuda_slice(weight_output, input_s.device().clone());
572        Ok((res, input_l.shape().clone()))
573    }
574
575    #[cfg(feature = "metal")]
576    fn metal_fwd(
577        &self,
578        _input_s: &candle_core::MetalStorage,
579        _input_l: &candle_core::Layout,
580    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
581        candle_core::bail!("FP8 blockwise quantization not yet implemented for Metal");
582    }
583}
584
585/// FP8 blockwise quantize.
586/// - Expects input to be f32, f16, or bf16
587/// - Returns a tuple of (quantized_weight, scales)
588/// - quantized_weight is fp8
589/// - scales is f32
590pub fn fp8_blockwise_quantize(
591    #[allow(unused_variables)] input: &Tensor,
592    #[allow(unused_variables)] weight_block_size: Vec<usize>,
593) -> Result<(Tensor, Tensor)> {
594    // Since CustomOp1 only returns a single tensor, we need a different approach
595    // Let's implement this using the CUDA kernels directly
596    #[cfg(feature = "cuda")]
597    {
598        use candle_core::{CudaStorage, Device, Storage};
599        use half::{bf16, f16};
600
601        use crate::{blockwise_fp8::ffi, utils::slice_ptr};
602
603        if !matches!(input.device(), Device::Cuda(_)) {
604            candle_core::bail!("FP8 blockwise quantization only supported on CUDA for now");
605        }
606
607        if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
608            candle_core::bail!("Do not have blockwise FP8 quant kernels.");
609        }
610
611        let input_l = input.layout();
612        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
613            candle_core::bail!("Expected input to have start offset 0, continuous");
614        }
615        if input.dims().len() != 2 {
616            candle_core::bail!("Expected input to be rank 2");
617        }
618        if weight_block_size.len() != 2 {
619            candle_core::bail!("Expected weight_block_size to have length 2");
620        }
621
622        let dev = match input.device() {
623            Device::Cuda(dev) => dev,
624            _ => unreachable!(),
625        };
626
627        let weight_height = input.dim(0)? as i32;
628        let weight_block_size_y = weight_block_size[0] as i32;
629        let weight_width = input.dim(1)? as i32;
630        let weight_block_size_x = weight_block_size[1] as i32;
631        let weight_row_stride = input_l.stride()[0] as i32;
632
633        let grid_y = input.dim(0)?.div_ceil(weight_block_size[0]);
634        let grid_x = input.dim(1)?.div_ceil(weight_block_size[1]);
635        let scale_stride = grid_x as i32;
636
637        // Allocate output buffers
638        let weight_output = dev.alloc_zeros::<F8E4M3>(input.shape().elem_count())?;
639        let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;
640
641        let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
642        let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);
643
644        match input.dtype() {
645            DType::F32 => {
646                let input_storage = input.storage_and_layout().0;
647                let input_s = match &*input_storage {
648                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
649                    _ => candle_core::bail!("Expected CUDA storage"),
650                };
651                let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
652                unsafe {
653                    ffi::launch_quant_fp8_blockwise_kernel_f32(
654                        input_ptr as *const _,
655                        weight_ptr as *mut _,
656                        scale_ptr as *mut _,
657                        weight_height,
658                        weight_width,
659                        weight_row_stride,
660                        scale_stride,
661                        weight_block_size_y,
662                        weight_block_size_x,
663                        dev.cuda_stream().cu_stream(),
664                    )
665                };
666            }
667            DType::F16 => {
668                let input_storage = input.storage_and_layout().0;
669                let input_s = match &*input_storage {
670                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
671                    _ => candle_core::bail!("Expected CUDA storage"),
672                };
673                let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
674                unsafe {
675                    ffi::launch_quant_fp8_blockwise_kernel_f16(
676                        input_ptr as *const _,
677                        weight_ptr as *mut _,
678                        scale_ptr as *mut _,
679                        weight_height,
680                        weight_width,
681                        weight_row_stride,
682                        scale_stride,
683                        weight_block_size_y,
684                        weight_block_size_x,
685                        dev.cuda_stream().cu_stream(),
686                    )
687                };
688            }
689            DType::BF16 => {
690                let input_storage = input.storage_and_layout().0;
691                let input_s = match &*input_storage {
692                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
693                    _ => candle_core::bail!("Expected CUDA storage"),
694                };
695                let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
696                unsafe {
697                    ffi::launch_quant_fp8_blockwise_kernel_bf16(
698                        input_ptr as *const _,
699                        weight_ptr as *mut _,
700                        scale_ptr as *mut _,
701                        weight_height,
702                        weight_width,
703                        weight_row_stride,
704                        scale_stride,
705                        weight_block_size_y,
706                        weight_block_size_x,
707                        dev.cuda_stream().cu_stream(),
708                    )
709                };
710            }
711            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
712        }
713
714        // Drop guards before moving the buffers
715        drop(_weight_guard);
716        drop(_scale_guard);
717
718        // Create weight tensor by wrapping the CUDA storage
719        let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
720        let weight =
721            from_storage_no_op(Storage::Cuda(weight_storage), input.shape().clone(), false);
722
723        // Create scale tensor
724        let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
725        let scale = from_storage_no_op(
726            Storage::Cuda(scale_storage),
727            candle_core::Shape::from_dims(&[grid_y, grid_x]),
728            false,
729        );
730
731        Ok((weight, scale))
732    }
733
734    #[cfg(not(feature = "cuda"))]
735    {
736        candle_core::bail!("FP8 blockwise quantization requires CUDA feature");
737    }
738}
739
#[cfg(test)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use half::bf16;

    // These imports are only needed by the CUDA-only GEMM test. Gating them keeps
    // non-CUDA builds free of unused-import warnings without a blanket `#[allow]`.
    #[cfg(feature = "cuda")]
    use candle_nn::{Linear, Module};
    #[cfg(feature = "cuda")]
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::blockwise_fp8::ops;
    #[cfg(feature = "cuda")]
    use crate::safetensors::MmapedSafetensors;

    /// Expected dequantization of the shared fixture as `f32`.
    ///
    /// The fixture is a 5x5 all-ones weight with 2x2 blocks and scales `0..9`
    /// arranged as a 3x3 grid, so every element equals the scale of its block.
    fn expected_f32() -> Vec<Vec<f32>> {
        vec![
            vec![0., 0., 1., 1., 2.],
            vec![0., 0., 1., 1., 2.],
            vec![3., 3., 4., 4., 5.],
            vec![3., 3., 4., 4., 5.],
            vec![6., 6., 7., 7., 8.],
        ]
    }

    /// Same expectation as [`expected_f32`], converted element-wise to `bf16`.
    fn expected_bf16() -> Vec<Vec<bf16>> {
        expected_f32()
            .into_iter()
            .map(|row| row.into_iter().map(bf16::from_f32).collect())
            .collect()
    }

    /// Dequantizes the shared fixture on `dev`, producing a tensor of dtype `out_ty`.
    ///
    /// The FP8 weight is always created on the CPU and then moved to `dev`, which is
    /// how the CUDA tests build their FP8 input; the scales are created directly on
    /// `dev`.
    fn dequant_ones(dev: &Device, out_ty: DType) -> Result<Tensor> {
        let weight = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?.to_device(dev)?;
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        ops::fp8_blockwise_dequantize(&weight, &inv_scales, vec![2, 2], out_ty)
    }

    /// CPU dequantization to `f32` reproduces the per-block scales.
    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let res = dequant_ones(&Device::Cpu, DType::F32)?.to_vec2::<f32>()?;
        assert_eq!(res, expected_f32());

        Ok(())
    }

    /// CUDA dequantization to `f32` matches both the CPU result and the literal
    /// expectation.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = dequant_ones(&Device::Cpu, DType::F32)?.to_vec2::<f32>()?;
        let test = dequant_ones(&Device::new_cuda(0)?, DType::F32)?.to_vec2::<f32>()?;

        assert_eq!(test, truth);
        assert_eq!(test, expected_f32());

        Ok(())
    }

    /// CPU dequantization to `bf16` reproduces the per-block scales.
    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let res = dequant_ones(&Device::Cpu, DType::BF16)?.to_vec2::<bf16>()?;
        assert_eq!(res, expected_bf16());

        Ok(())
    }

    /// CUDA dequantization to `bf16` matches both the CPU result and the literal
    /// expectation.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = dequant_ones(&Device::Cpu, DType::BF16)?.to_vec2::<bf16>()?;
        let test = dequant_ones(&Device::new_cuda(0)?, DType::BF16)?.to_vec2::<bf16>()?;

        assert_eq!(test, truth);
        assert_eq!(test, expected_bf16());

        Ok(())
    }

    /// Quantizing and then dequantizing a random 8x8 tensor keeps the shape and stays
    /// within a small absolute error.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        // Create test input
        let input = Tensor::randn(0f32, 2f32, (8, 8), dev)?;
        let weight_block_size = vec![4, 4];

        // Quantize
        let (quantized, scales) = ops::fp8_blockwise_quantize(&input, weight_block_size.clone())?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims2()?, (2, 2)); // 8/4 = 2 blocks in each dimension

        // Dequantize
        let dequantized =
            ops::fp8_blockwise_dequantize(&quantized, &scales, weight_block_size, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec2::<f32>()?;
        let dequant_vec = dequantized.to_vec2::<f32>()?;

        // Shapes were asserted equal above, so walking both matrices flattened pairs
        // up exactly the same elements as a row-by-row comparison.
        let max_error = input_vec
            .iter()
            .flatten()
            .zip(dequant_vec.iter().flatten())
            .map(|(val_in, val_out)| (val_in - val_out).abs())
            .fold(0f32, f32::max);

        // FP8 E4M3 has limited precision, so we expect some error
        // but it should be reasonable
        assert!(max_error < 0.16, "Max error {} is too large", max_error);

        Ok(())
    }

    /// Loads a real FP8 weight/scale pair from the Hugging Face Hub, dequantizes it to
    /// `bf16` and checks the output shape of a linear layer built from it.
    ///
    /// Requires network access (or a warm Hub cache) in addition to the CUDA feature.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let hub = ApiBuilder::new()
            .with_progress(true)
            .build()
            .expect("failed to build the Hugging Face Hub API client");
        let repo = hub.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let filename = repo
            .get("test_fp8.safetensors")
            .expect("failed to fetch test_fp8.safetensors from the Hub");
        // SAFETY: NOTE(review): presumably sound because the downloaded file is not
        // modified while it is mapped during this test; confirm against the
        // `MmapedSafetensors::new` contract.
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        // in dim is 2048.
        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        let truth = {
            let weight_dq =
                ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

            let lin_dq = Linear::new(weight_dq, None);
            lin_dq.forward(&xs)?
        };

        // TODO: will be adding real blockwise fp8 gemm shortly ;)
        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}