// mistralrs_quant/blockwise_fp8/ops.rs

1use candle_core::{CpuStorage, CustomOp1, CustomOp2, DType, Result, Tensor, WithDType};
2use float8::F8E4M3;
3use rayon::iter::{IntoParallelIterator, ParallelIterator};
4
5struct Fp8BlockwiseDequantize {
6    weight_block_size: Vec<usize>,
7    out_ty: DType,
8}
9
10impl Fp8BlockwiseDequantize {
11    fn dispatch_dequant_blockwise<T: WithDType>(
12        &self,
13        weight: &[F8E4M3],
14        scale: &[f32],
15        weight_l: &candle_core::Layout,
16        scale_l: &candle_core::Layout,
17    ) -> candle_core::Result<Vec<T>> {
18        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
19        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);
20
21        let res = vec![T::zero(); weight.len()];
22
23        (0..grid_y).into_par_iter().for_each(|y| {
24            (0..grid_x).into_par_iter().for_each(|x| {
25                let res_ptr = res.as_ptr() as *mut T;
26
27                let scale = scale[y * scale_l.stride()[0] + x];
28
29                let start_y = y * self.weight_block_size[0];
30                let end_y = start_y + self.weight_block_size[0];
31
32                let start_x = x * self.weight_block_size[1];
33                let end_x = start_x + self.weight_block_size[1];
34
35                for weight_y in start_y..end_y {
36                    if weight_y >= weight_l.dims()[0] {
37                        break;
38                    }
39
40                    let row_offset = weight_y * weight_l.stride()[0];
41                    for weight_x in start_x..end_x {
42                        if weight_x >= weight_l.dims()[1] {
43                            break;
44                        }
45
46                        let weight_pos = row_offset + weight_x;
47
48                        // SAFETY: We know each thread will only update indepedant values!
49                        unsafe {
50                            *res_ptr.wrapping_add(weight_pos) =
51                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
52                        }
53                    }
54                }
55            });
56        });
57
58        Ok(res)
59    }
60}
61
62impl CustomOp2 for Fp8BlockwiseDequantize {
63    fn name(&self) -> &'static str {
64        "fp8-blockwise-dequantize"
65    }
66
67    fn cpu_fwd(
68        &self,
69        scale_s: &candle_core::CpuStorage,
70        scale_l: &candle_core::Layout,
71        weight_s: &candle_core::CpuStorage,
72        weight_l: &candle_core::Layout,
73    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
74        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
75            candle_core::bail!("Expected F8E4M3 weight!");
76        };
77        let candle_core::CpuStorage::F32(scale) = scale_s else {
78            candle_core::bail!("Expected F8E4M3 weight!");
79        };
80        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
81            candle_core::bail!("Expected weight to have start offset 0, continuous");
82        }
83        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
84            candle_core::bail!("Expected scales to have start offset 0, continuous");
85        }
86        if weight_l.dims().len() != 2 {
87            candle_core::bail!("Expected weight to be rank 2");
88        }
89        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
90            candle_core::bail!("Expected scale to be rank 2");
91        }
92
93        match self.out_ty {
94            DType::F32 => Ok((
95                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
96                weight_l.shape().clone(),
97            )),
98            DType::BF16 => Ok((
99                CpuStorage::BF16(
100                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
101                ),
102                weight_l.shape().clone(),
103            )),
104            DType::F16 => Ok((
105                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
106                weight_l.shape().clone(),
107            )),
108            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
109        }
110    }
111
112    #[cfg(feature = "cuda")]
113    fn cuda_fwd(
114        &self,
115        scale_s: &candle_core::CudaStorage,
116        scale_l: &candle_core::Layout,
117        weight_s: &candle_core::CudaStorage,
118        weight_l: &candle_core::Layout,
119    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
120        use candle_core::{backend::BackendStorage, CudaStorage};
121        use half::{bf16, f16};
122
123        use crate::{blockwise_fp8::ffi, utils::slice_ptr};
124
125        if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
126            candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
127        }
128
129        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
130            candle_core::bail!("Expected weight to have start offset 0, continuous");
131        }
132        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
133            candle_core::bail!("Expected scales to have start offset 0, continuous");
134        }
135        if weight_l.dims().len() != 2 {
136            candle_core::bail!("Expected weight to be rank 2");
137        }
138        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
139            candle_core::bail!("Expected scale to be rank 2");
140        }
141
142        let dev = weight_s.device();
143
144        let (weight, _weight_guard) =
145            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
146        let (scale, _scale_guard) =
147            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());
148
149        let weight_height = weight_l.dim(0)? as i32;
150        let weight_block_size_x = self.weight_block_size[0] as i32;
151        let weight_width = weight_l.dim(1)? as i32;
152        let weight_block_size_y = self.weight_block_size[1] as i32;
153        let scale_stride = scale_l.stride()[0] as i32;
154        let weight_row_stride = weight_l.stride()[0] as i32;
155
156        let res = match self.out_ty {
157            DType::F32 => {
158                let output = weight_s
159                    .device()
160                    .alloc_zeros::<f32>(weight_l.shape().elem_count())?;
161                let (output_ptr, output_guard) = slice_ptr(&output, 0);
162                unsafe {
163                    ffi::launch_dequant_fp8_blockwise_kernel_f32(
164                        weight as *const _,
165                        scale as *const _,
166                        output_ptr as *mut _,
167                        weight_height,
168                        weight_width,
169                        weight_row_stride,
170                        scale_stride,
171                        weight_block_size_y,
172                        weight_block_size_x,
173                        dev.cuda_stream().cu_stream(),
174                    )
175                };
176                drop(output_guard);
177                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
178            }
179            DType::F16 => {
180                let output = weight_s
181                    .device()
182                    .alloc_zeros::<f16>(weight_l.shape().elem_count())?;
183                let (output_ptr, output_guard) = slice_ptr(&output, 0);
184                unsafe {
185                    ffi::launch_dequant_fp8_blockwise_kernel_f16(
186                        weight as *const _,
187                        scale as *const _,
188                        output_ptr as *mut _,
189                        weight_height,
190                        weight_width,
191                        weight_row_stride,
192                        scale_stride,
193                        weight_block_size_y,
194                        weight_block_size_x,
195                        dev.cuda_stream().cu_stream(),
196                    )
197                };
198                drop(output_guard);
199                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
200            }
201            DType::BF16 => {
202                let output = weight_s
203                    .device()
204                    .alloc_zeros::<bf16>(weight_l.shape().elem_count())?;
205                let (output_ptr, output_guard) = slice_ptr(&output, 0);
206                unsafe {
207                    ffi::launch_dequant_fp8_blockwise_kernel_bf16(
208                        weight as *const _,
209                        scale as *const _,
210                        output_ptr as *mut _,
211                        weight_height,
212                        weight_width,
213                        weight_row_stride,
214                        scale_stride,
215                        weight_block_size_y,
216                        weight_block_size_x,
217                        dev.cuda_stream().cu_stream(),
218                    )
219                };
220                drop(output_guard);
221                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
222            }
223            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
224        };
225
226        Ok((res, weight_l.shape().clone()))
227    }
228
229    #[cfg(feature = "metal")]
230    fn metal_fwd(
231        &self,
232        scale_s: &candle_core::MetalStorage,
233        scale_l: &candle_core::Layout,
234        weight_s: &candle_core::MetalStorage,
235        weight_l: &candle_core::Layout,
236    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
237        use candle_core::backend::BackendStorage;
238
239        if weight_l.start_offset() != 0
240            || !weight_l.is_contiguous()
241            || weight_s.dtype() != DType::F8E4M3
242        {
243            candle_core::bail!("Expected f8e4m3 weight to have start offset 0, continuous");
244        }
245        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() || scale_s.dtype() != DType::F32
246        {
247            candle_core::bail!("Expected f32 scales to have start offset 0, continuous");
248        }
249        if weight_l.dims().len() != 2 {
250            candle_core::bail!("Expected weight to be rank 2");
251        }
252        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
253            candle_core::bail!("Expected scale to be rank 2");
254        }
255
256        let encoder = weight_s.device().command_encoder()?;
257        encoder.set_label("dequant-blockwise-fp8");
258
259        let device = weight_s.device();
260
261        let out_shape = weight_l.shape().clone();
262
263        let output = device.new_buffer(
264            out_shape.elem_count(),
265            weight_s.dtype(),
266            "dequant-blockwise-fp8",
267        )?;
268
269        let weight_height = weight_l.dim(0)? as u32;
270        let weight_block_size_x = self.weight_block_size[0] as u32;
271        let weight_width = weight_l.dim(1)? as u32;
272        let weight_block_size_y = self.weight_block_size[1] as u32;
273        let scale_stride = scale_l.stride()[0] as u32;
274        let weight_row_stride = weight_l.stride()[0] as u32;
275
276        crate::metal_kernels::call_dequant_blockwise_fp8(
277            device.device(),
278            &encoder,
279            &crate::metal_kernels::Kernels::new(),
280            self.out_ty,
281            weight_s.buffer(),
282            scale_s.buffer(),
283            &output,
284            weight_height,
285            weight_width,
286            weight_row_stride,
287            scale_stride,
288            weight_block_size_y,
289            weight_block_size_x,
290        )
291        .map_err(candle_core::Error::wrap)?;
292
293        let newstorage = candle_core::MetalStorage::new(
294            output,
295            device.clone(),
296            out_shape.elem_count(),
297            self.out_ty,
298        );
299        Ok((newstorage, out_shape))
300    }
301}
302
303/// FP8 blockwise dequantize.
304/// - Expects weight to be fp8
305/// - Expects inv_scales to be f32
306/// - weight * inv_scale = dequantized
307pub fn fp8_blockwise_dequantize(
308    weight: &Tensor,
309    inv_scales: &Tensor,
310    weight_block_size: Vec<usize>,
311    out_ty: DType,
312) -> Result<Tensor> {
313    inv_scales.apply_op2_no_bwd(
314        weight,
315        &Fp8BlockwiseDequantize {
316            weight_block_size,
317            out_ty,
318        },
319    )
320}
321
/// Custom op that quantizes a rank-2 tensor to F8E4M3 with one f32 scale per block.
// Currently unused, hence `dead_code`: a `CustomOp1` has a single output and cannot hand back
// the scales, so `fp8_blockwise_quantize` drives the kernels directly instead.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Fp8BlockwiseQuantize {
    /// Block extent as `[rows, cols]`: each scale covers this many input rows/columns.
    weight_block_size: Vec<usize>,
}
326
327impl Fp8BlockwiseQuantize {
328    #[allow(dead_code)]
329    fn dispatch_quant_blockwise<T: WithDType>(
330        &self,
331        input: &[T],
332        input_l: &candle_core::Layout,
333    ) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
334        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
335        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
336
337        let weight = vec![F8E4M3::from_f32(0.0); input.len()];
338        let scale = vec![0f32; grid_y * grid_x];
339
340        (0..grid_y).into_par_iter().for_each(|y| {
341            (0..grid_x).into_par_iter().for_each(|x| {
342                let weight_ptr = weight.as_ptr() as *mut F8E4M3;
343                let scale_ptr = scale.as_ptr() as *mut f32;
344
345                let start_y = y * self.weight_block_size[0];
346                let end_y = start_y + self.weight_block_size[0];
347
348                let start_x = x * self.weight_block_size[1];
349                let end_x = start_x + self.weight_block_size[1];
350
351                // Find max absolute value in block
352                let mut max_abs = 0f32;
353                for weight_y in start_y..end_y {
354                    if weight_y >= input_l.dims()[0] {
355                        break;
356                    }
357
358                    let row_offset = weight_y * input_l.stride()[0];
359                    for weight_x in start_x..end_x {
360                        if weight_x >= input_l.dims()[1] {
361                            break;
362                        }
363
364                        let pos = row_offset + weight_x;
365                        let val = input[pos].to_f64() as f32;
366                        let abs_val = val.abs();
367                        if abs_val > max_abs {
368                            max_abs = abs_val;
369                        }
370                    }
371                }
372
373                // Calculate scale
374                let block_scale = if max_abs > 0.0 {
375                    max_abs / 448.0
376                } else {
377                    1e-12
378                };
379
380                // SAFETY: We know each thread will only update independent values!
381                unsafe {
382                    *scale_ptr.wrapping_add(y * grid_x + x) = block_scale;
383                }
384
385                // Quantize values
386                for weight_y in start_y..end_y {
387                    if weight_y >= input_l.dims()[0] {
388                        break;
389                    }
390
391                    let row_offset = weight_y * input_l.stride()[0];
392                    for weight_x in start_x..end_x {
393                        if weight_x >= input_l.dims()[1] {
394                            break;
395                        }
396
397                        let pos = row_offset + weight_x;
398                        let val = input[pos].to_f64() as f32;
399                        let scaled_val = (val / block_scale).clamp(-448.0, 448.0);
400
401                        // SAFETY: We know each thread will only update independent values!
402                        unsafe {
403                            *weight_ptr.wrapping_add(pos) = F8E4M3::from_f32(scaled_val);
404                        }
405                    }
406                }
407            });
408        });
409
410        Ok((weight, scale))
411    }
412}
413
414impl CustomOp1 for Fp8BlockwiseQuantize {
415    fn name(&self) -> &'static str {
416        "fp8-blockwise-quantize"
417    }
418
419    fn cpu_fwd(
420        &self,
421        input_s: &candle_core::CpuStorage,
422        input_l: &candle_core::Layout,
423    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
424        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
425            candle_core::bail!("Expected input to have start offset 0, continuous");
426        }
427        if input_l.dims().len() != 2 {
428            candle_core::bail!("Expected input to be rank 2");
429        }
430        if self.weight_block_size.len() != 2 {
431            candle_core::bail!("Expected weight_block_size to have length 2");
432        }
433
434        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
435        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
436
437        let (weight, scale) = match input_s {
438            CpuStorage::F32(input) => self.dispatch_quant_blockwise(input, input_l)?,
439            CpuStorage::F16(input) => self.dispatch_quant_blockwise(input, input_l)?,
440            CpuStorage::BF16(input) => self.dispatch_quant_blockwise(input, input_l)?,
441            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
442        };
443
444        // Return both weight and scale tensors packed into a single storage
445        // We'll need to unpack them after the op
446        let mut packed = Vec::with_capacity(weight.len() + scale.len());
447        packed.extend_from_slice(&weight);
448
449        // Convert scale to F8E4M3 for storage (will convert back when unpacking)
450        for &s in &scale {
451            packed.push(F8E4M3::from_f32(s));
452        }
453
454        Ok((
455            CpuStorage::F8E4M3(packed),
456            candle_core::Shape::from_dims(&[
457                input_l.dims()[0] + grid_y,
458                input_l.dims()[1].max(grid_x),
459            ]),
460        ))
461    }
462
463    #[cfg(feature = "cuda")]
464    fn cuda_fwd(
465        &self,
466        input_s: &candle_core::CudaStorage,
467        input_l: &candle_core::Layout,
468    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
469        use candle_core::{backend::BackendStorage, CudaStorage};
470        use half::{bf16, f16};
471
472        use crate::{blockwise_fp8::ffi, utils::slice_ptr};
473
474        if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
475            candle_core::bail!("Do not have blockwise FP8 quant kernels.");
476        }
477
478        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
479            candle_core::bail!("Expected input to have start offset 0, continuous");
480        }
481        if input_l.dims().len() != 2 {
482            candle_core::bail!("Expected input to be rank 2");
483        }
484        if self.weight_block_size.len() != 2 {
485            candle_core::bail!("Expected weight_block_size to have length 2");
486        }
487
488        let dev = input_s.device();
489
490        let weight_height = input_l.dim(0)? as i32;
491        let weight_block_size_y = self.weight_block_size[0] as i32;
492        let weight_width = input_l.dim(1)? as i32;
493        let weight_block_size_x = self.weight_block_size[1] as i32;
494        let weight_row_stride = input_l.stride()[0] as i32;
495
496        let grid_y = input_l.dim(0)?.div_ceil(self.weight_block_size[0]);
497        let grid_x = input_l.dim(1)?.div_ceil(self.weight_block_size[1]);
498        let scale_stride = grid_x as i32;
499
500        // Allocate output buffers
501        let weight_output = dev.alloc_zeros::<F8E4M3>(input_l.shape().elem_count())?;
502        let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;
503
504        let (weight_ptr, weight_guard) = slice_ptr(&weight_output, 0);
505        let (scale_ptr, scale_guard) = slice_ptr(&scale_output, 0);
506
507        match input_s.dtype() {
508            DType::F32 => {
509                let (input, _input_guard) =
510                    slice_ptr(input_s.as_cuda_slice::<f32>()?, input_l.start_offset());
511                unsafe {
512                    ffi::launch_quant_fp8_blockwise_kernel_f32(
513                        input as *const _,
514                        weight_ptr as *mut _,
515                        scale_ptr as *mut _,
516                        weight_height,
517                        weight_width,
518                        weight_row_stride,
519                        scale_stride,
520                        weight_block_size_y,
521                        weight_block_size_x,
522                        dev.cuda_stream().cu_stream(),
523                    )
524                };
525            }
526            DType::F16 => {
527                let (input, _input_guard) =
528                    slice_ptr(input_s.as_cuda_slice::<f16>()?, input_l.start_offset());
529                unsafe {
530                    ffi::launch_quant_fp8_blockwise_kernel_f16(
531                        input as *const _,
532                        weight_ptr as *mut _,
533                        scale_ptr as *mut _,
534                        weight_height,
535                        weight_width,
536                        weight_row_stride,
537                        scale_stride,
538                        weight_block_size_y,
539                        weight_block_size_x,
540                        dev.cuda_stream().cu_stream(),
541                    )
542                };
543            }
544            DType::BF16 => {
545                let (input, _input_guard) =
546                    slice_ptr(input_s.as_cuda_slice::<bf16>()?, input_l.start_offset());
547                unsafe {
548                    ffi::launch_quant_fp8_blockwise_kernel_bf16(
549                        input as *const _,
550                        weight_ptr as *mut _,
551                        scale_ptr as *mut _,
552                        weight_height,
553                        weight_width,
554                        weight_row_stride,
555                        scale_stride,
556                        weight_block_size_y,
557                        weight_block_size_x,
558                        dev.cuda_stream().cu_stream(),
559                    )
560                };
561            }
562            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
563        }
564
565        drop(weight_guard);
566        drop(scale_guard);
567
568        // Return just the weight tensor - we'll handle scale separately
569        let res = CudaStorage::wrap_cuda_slice(weight_output, input_s.device().clone());
570        Ok((res, input_l.shape().clone()))
571    }
572
573    #[cfg(feature = "metal")]
574    fn metal_fwd(
575        &self,
576        _input_s: &candle_core::MetalStorage,
577        _input_l: &candle_core::Layout,
578    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
579        candle_core::bail!("FP8 blockwise quantization not yet implemented for Metal");
580    }
581}
582
583/// FP8 blockwise quantize.
584/// - Expects input to be f32, f16, or bf16
585/// - Returns a tuple of (quantized_weight, scales)
586/// - quantized_weight is fp8
587/// - scales is f32
588pub fn fp8_blockwise_quantize(
589    #[allow(unused_variables)] input: &Tensor,
590    #[allow(unused_variables)] weight_block_size: Vec<usize>,
591) -> Result<(Tensor, Tensor)> {
592    // Since CustomOp1 only returns a single tensor, we need a different approach
593    // Let's implement this using the CUDA kernels directly
594    #[cfg(feature = "cuda")]
595    {
596        use candle_core::{CudaStorage, Device, Storage};
597        use half::{bf16, f16};
598
599        use crate::{blockwise_fp8::ffi, utils::slice_ptr};
600
601        if !matches!(input.device(), Device::Cuda(_)) {
602            candle_core::bail!("FP8 blockwise quantization only supported on CUDA for now");
603        }
604
605        if !ffi::HAVE_BLOCKWISE_QUANT_KERNELS {
606            candle_core::bail!("Do not have blockwise FP8 quant kernels.");
607        }
608
609        let input_l = input.layout();
610        if input_l.start_offset() != 0 || !input_l.is_contiguous() {
611            candle_core::bail!("Expected input to have start offset 0, continuous");
612        }
613        if input.dims().len() != 2 {
614            candle_core::bail!("Expected input to be rank 2");
615        }
616        if weight_block_size.len() != 2 {
617            candle_core::bail!("Expected weight_block_size to have length 2");
618        }
619
620        let dev = match input.device() {
621            Device::Cuda(dev) => dev,
622            _ => unreachable!(),
623        };
624
625        let weight_height = input.dim(0)? as i32;
626        let weight_block_size_y = weight_block_size[0] as i32;
627        let weight_width = input.dim(1)? as i32;
628        let weight_block_size_x = weight_block_size[1] as i32;
629        let weight_row_stride = input_l.stride()[0] as i32;
630
631        let grid_y = input.dim(0)?.div_ceil(weight_block_size[0]);
632        let grid_x = input.dim(1)?.div_ceil(weight_block_size[1]);
633        let scale_stride = grid_x as i32;
634
635        // Allocate output buffers
636        let weight_output = dev.alloc_zeros::<F8E4M3>(input.shape().elem_count())?;
637        let scale_output = dev.alloc_zeros::<f32>(grid_y * grid_x)?;
638
639        let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
640        let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);
641
642        match input.dtype() {
643            DType::F32 => {
644                let input_storage = input.storage_and_layout().0;
645                let input_s = match &*input_storage {
646                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
647                    _ => candle_core::bail!("Expected CUDA storage"),
648                };
649                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
650                unsafe {
651                    ffi::launch_quant_fp8_blockwise_kernel_f32(
652                        input_ptr as *const _,
653                        weight_ptr as *mut _,
654                        scale_ptr as *mut _,
655                        weight_height,
656                        weight_width,
657                        weight_row_stride,
658                        scale_stride,
659                        weight_block_size_y,
660                        weight_block_size_x,
661                        dev.cuda_stream().cu_stream(),
662                    )
663                };
664            }
665            DType::F16 => {
666                let input_storage = input.storage_and_layout().0;
667                let input_s = match &*input_storage {
668                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
669                    _ => candle_core::bail!("Expected CUDA storage"),
670                };
671                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
672                unsafe {
673                    ffi::launch_quant_fp8_blockwise_kernel_f16(
674                        input_ptr as *const _,
675                        weight_ptr as *mut _,
676                        scale_ptr as *mut _,
677                        weight_height,
678                        weight_width,
679                        weight_row_stride,
680                        scale_stride,
681                        weight_block_size_y,
682                        weight_block_size_x,
683                        dev.cuda_stream().cu_stream(),
684                    )
685                };
686            }
687            DType::BF16 => {
688                let input_storage = input.storage_and_layout().0;
689                let input_s = match &*input_storage {
690                    Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
691                    _ => candle_core::bail!("Expected CUDA storage"),
692                };
693                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
694                unsafe {
695                    ffi::launch_quant_fp8_blockwise_kernel_bf16(
696                        input_ptr as *const _,
697                        weight_ptr as *mut _,
698                        scale_ptr as *mut _,
699                        weight_height,
700                        weight_width,
701                        weight_row_stride,
702                        scale_stride,
703                        weight_block_size_y,
704                        weight_block_size_x,
705                        dev.cuda_stream().cu_stream(),
706                    )
707                };
708            }
709            other => candle_core::bail!("unexpected input type for fp8 blockwise quant: {other:?}"),
710        }
711
712        // Drop guards before moving the buffers
713        drop(_weight_guard);
714        drop(_scale_guard);
715
716        // Create weight tensor by wrapping the CUDA storage
717        let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
718        let weight = Tensor::from((Storage::Cuda(weight_storage), input.shape().clone()));
719
720        // Create scale tensor
721        let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
722        let scale = Tensor::from((
723            Storage::Cuda(scale_storage),
724            candle_core::Shape::from_dims(&[grid_y, grid_x]),
725        ));
726
727        Ok((weight, scale))
728    }
729
730    #[cfg(not(feature = "cuda"))]
731    {
732        candle_core::bail!("FP8 blockwise quantization requires CUDA feature");
733    }
734}
735
/// FP8 blockwise matmul.
///
/// Computes `output = input @ weight.T` where `weight` is FP8 (E4M3) blockwise quantized.
/// - `input`: `[M, K]` in fp16/bf16
/// - `weight`: `[N, K]` in FP8 E4M3 with blockwise scales
/// - `scales`: at least `[ceil(N / block_y), ceil(K / block_x)]` in f32; the number of
///   columns is passed to the kernel as the scale row stride
/// - `weight_block_size`: `[block_y, block_x]` (both non-zero)
/// - output: `[M, N]` in the dtype of `input` (fp16/bf16)
///
/// # Errors
/// Fails if the blockwise GEMM kernels were not built, if `input` is not on a CUDA
/// device, on rank/dtype/shape mismatches, if `weight_block_size` does not hold two
/// non-zero entries, or if any size does not fit the kernel's `i32` arguments.
#[cfg(feature = "cuda")]
pub fn fp8_blockwise_matmul(
    input: &Tensor,
    weight: &Tensor,
    scales: &Tensor,
    weight_block_size: &[usize],
) -> Result<Tensor> {
    use candle_core::{CudaStorage, Device, Storage};
    use half::{bf16, f16};

    use crate::{blockwise_fp8::ffi, utils::slice_ptr};

    /// Converts a size to the `i32` the kernel launcher takes, erroring instead of
    /// silently truncating.
    fn to_i32(value: usize, what: &str) -> Result<i32> {
        match i32::try_from(value) {
            Ok(v) => Ok(v),
            Err(_) => candle_core::bail!(
                "FP8 blockwise matmul: {what} ({value}) does not fit in an i32 kernel argument"
            ),
        }
    }

    if !ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
        candle_core::bail!("Do not have blockwise FP8 GEMM kernels.");
    }

    let Device::Cuda(dev) = input.device() else {
        candle_core::bail!("FP8 blockwise matmul only supported on CUDA");
    };

    // Validate up front instead of panicking on `weight_block_size[i]` or dividing by zero.
    let &[block_y, block_x, ..] = weight_block_size else {
        candle_core::bail!(
            "Expected weight_block_size to have 2 entries, got {:?}",
            weight_block_size
        );
    };
    if block_y == 0 || block_x == 0 {
        candle_core::bail!(
            "weight_block_size entries must be non-zero, got {:?}",
            weight_block_size
        );
    }

    let input = input.contiguous()?;
    let weight = weight.contiguous()?;
    let scales = scales.contiguous()?;

    if input.dims().len() != 2 {
        candle_core::bail!("Expected input to be rank 2, got {:?}", input.dims());
    }
    if weight.dims().len() != 2 {
        candle_core::bail!("Expected weight to be rank 2, got {:?}", weight.dims());
    }
    if weight.dtype() != DType::F8E4M3 {
        candle_core::bail!("Expected FP8 weight, got {:?}", weight.dtype());
    }
    if scales.dims().len() != 2 {
        candle_core::bail!("Expected scales to be rank 2, got {:?}", scales.dims());
    }

    let (m, k) = input.dims2()?;
    let (n, weight_k) = weight.dims2()?;
    if weight_k != k {
        candle_core::bail!(
            "Weight K dimension {} doesn't match input K dimension {}",
            weight_k,
            k
        );
    }

    // One scale per (block_y x block_x) weight tile, matching the ceil-divided grid of
    // the CPU dequantize path; a smaller grid would be read past its end.
    let (scale_rows, scale_cols) = scales.dims2()?;
    let (min_rows, min_cols) = (n.div_ceil(block_y), k.div_ceil(block_x));
    if scale_rows < min_rows || scale_cols < min_cols {
        candle_core::bail!(
            "Expected scales of at least [{min_rows}, {min_cols}] for weight {:?} with block size [{block_y}, {block_x}], got {:?}",
            weight.dims(),
            scales.dims()
        );
    }

    let m_i32 = to_i32(m, "M")?;
    let n_i32 = to_i32(n, "N")?;
    let k_i32 = to_i32(k, "K")?;
    let scale_row_stride = to_i32(scale_cols, "scale row stride")?;
    let block_size_y = to_i32(block_y, "block_size_y")?;
    let block_size_x = to_i32(block_x, "block_size_x")?;
    // Element count in usize: an i32 product of M and N can wrap and under-allocate.
    let out_numel = m * n;

    let input_l = input.layout();
    let weight_l = weight.layout();
    let scales_l = scales.layout();

    let input_storage = input.storage_and_layout().0;
    let weight_storage = weight.storage_and_layout().0;
    let scales_storage = scales.storage_and_layout().0;

    let weight_s = match &*weight_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<F8E4M3>()?,
        _ => candle_core::bail!("Expected CUDA storage for weight"),
    };
    let scales_s = match &*scales_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
        _ => candle_core::bail!("Expected CUDA storage for scales"),
    };

    let (weight_ptr, _weight_guard) = slice_ptr(weight_s, weight_l.start_offset());
    let (scales_ptr, _scales_guard) = slice_ptr(scales_s, scales_l.start_offset());

    match input.dtype() {
        DType::F16 => {
            let output = dev.alloc_zeros::<f16>(out_numel)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                // SAFETY: all pointers come from live CUDA slices whose guards are held for
                // this scope; the output holds M*N elements and the shapes were validated above.
                unsafe {
                    ffi::launch_fp8_matmul_f16(
                        input_ptr as *const _,
                        weight_ptr as *const _,
                        scales_ptr as *const _,
                        output_ptr as *mut _,
                        m_i32,
                        n_i32,
                        k_i32,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[m, n]),
            )))
        }
        DType::BF16 => {
            let output = dev.alloc_zeros::<bf16>(out_numel)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                // SAFETY: all pointers come from live CUDA slices whose guards are held for
                // this scope; the output holds M*N elements and the shapes were validated above.
                unsafe {
                    ffi::launch_fp8_matmul_bf16(
                        input_ptr as *const _,
                        weight_ptr as *const _,
                        scales_ptr as *const _,
                        output_ptr as *mut _,
                        m_i32,
                        n_i32,
                        k_i32,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[m, n]),
            )))
        }
        other => candle_core::bail!("Unsupported input dtype for FP8 matmul: {:?}", other),
    }
}
891
/// FP8 indexed MoE GEMM for gather_forward.
///
/// Computes an indexed matmul for MoE, where each token selects specific experts via `indices`.
/// - `input`: `[num_tokens, K]`, `[num_tokens, 1, K]` or `[num_tokens, topk, K]` in
///   fp16/bf16. A 3-D input whose middle dimension is not 1 must have it equal to `topk`.
/// - `weights`: `[num_experts, N, K]` in FP8 E4M3 with blockwise scales
/// - `scales`: at least `[num_experts, ceil(N / block_y), ceil(K / block_x)]` in f32; the
///   last dimension is passed to the kernel as the scale row stride.
///   NOTE(review): the kernel only receives the row stride, so confirm how it derives the
///   per-expert stride before relying on padded scale tensors.
/// - `indices`: `[num_tokens, topk]` expert ids stored as `u32`. Their values are not
///   range-checked here (that would need a device-to-host copy); callers must keep them
///   below `num_experts`.
/// - `weight_block_size`: `[block_y, block_x]` (both non-zero)
/// - output: `[num_tokens, topk, N]` in the dtype of `input` (fp16/bf16)
///
/// # Errors
/// Fails if the blockwise GEMM kernels were not built, if `input` is not on a CUDA
/// device, on rank/dtype/shape mismatches, if `weight_block_size` does not hold two
/// non-zero entries, or if any size does not fit the kernel's `i32` arguments.
#[cfg(feature = "cuda")]
pub fn fp8_indexed_moe_gemm(
    input: &Tensor,
    weights: &Tensor,
    scales: &Tensor,
    indices: &Tensor,
    weight_block_size: &[usize],
) -> Result<Tensor> {
    use candle_core::{CudaStorage, Device, Storage};
    use half::{bf16, f16};

    use crate::{blockwise_fp8::ffi, utils::slice_ptr};

    /// Converts a size to the `i32` the kernel launcher takes, erroring instead of
    /// silently truncating.
    fn to_i32(value: usize, what: &str) -> Result<i32> {
        match i32::try_from(value) {
            Ok(v) => Ok(v),
            Err(_) => candle_core::bail!(
                "FP8 indexed MoE GEMM: {what} ({value}) does not fit in an i32 kernel argument"
            ),
        }
    }

    if !ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
        candle_core::bail!("Do not have blockwise FP8 GEMM kernels.");
    }

    let Device::Cuda(dev) = input.device() else {
        candle_core::bail!("FP8 indexed MoE GEMM only supported on CUDA");
    };

    // Validate up front instead of panicking on `weight_block_size[i]` or dividing by zero.
    let &[block_y, block_x, ..] = weight_block_size else {
        candle_core::bail!(
            "Expected weight_block_size to have 2 entries, got {:?}",
            weight_block_size
        );
    };
    if block_y == 0 || block_x == 0 {
        candle_core::bail!(
            "weight_block_size entries must be non-zero, got {:?}",
            weight_block_size
        );
    }

    let input = input.contiguous()?;
    let weights = weights.contiguous()?;
    let scales = scales.contiguous()?;
    let indices = indices.contiguous()?;

    // Input can be [num_tokens, K], [num_tokens, 1, K] or [num_tokens, topk, K].
    let (num_tokens, input_slots, k) = match *input.dims() {
        [tokens, slots, k] => (tokens, slots, k),
        [tokens, k] => (tokens, 1, k),
        _ => candle_core::bail!("Expected input to be rank 2 or 3, got {:?}", input.dims()),
    };

    // Get topk from indices
    let (indices_tokens, topk) = indices.dims2()?;
    if indices_tokens != num_tokens {
        candle_core::bail!(
            "Indices num_tokens {} doesn't match input num_tokens {}",
            indices_tokens,
            num_tokens
        );
    }
    // Any middle dimension other than 1 is treated as one input row per slot, so it
    // must match topk exactly or the kernel would read a mis-sized input.
    if input_slots != 1 && input_slots != topk {
        candle_core::bail!(
            "Input topk dimension {} must be 1 or match indices topk {}",
            input_slots,
            topk
        );
    }
    let input_has_topk_dim = input_slots > 1;

    // Weights shape: [num_experts, N, K]
    if weights.dims().len() != 3 {
        candle_core::bail!("Expected weights to be rank 3, got {:?}", weights.dims());
    }
    let (num_experts, n, weight_k) = weights.dims3()?;
    if weight_k != k {
        candle_core::bail!(
            "Weights K dimension {} doesn't match input K dimension {}",
            weight_k,
            k
        );
    }

    if weights.dtype() != DType::F8E4M3 {
        candle_core::bail!("Expected FP8 weights, got {:?}", weights.dtype());
    }

    // Scales must cover every expert and every (block_y x block_x) tile of its weight.
    if scales.dims().len() != 3 {
        candle_core::bail!("Expected scales to be rank 3, got {:?}", scales.dims());
    }
    let (scale_experts, scale_rows, scale_cols) = scales.dims3()?;
    let (min_rows, min_cols) = (n.div_ceil(block_y), k.div_ceil(block_x));
    if scale_experts < num_experts || scale_rows < min_rows || scale_cols < min_cols {
        candle_core::bail!(
            "Expected scales of at least [{num_experts}, {min_rows}, {min_cols}] for weights {:?} with block size [{block_y}, {block_x}], got {:?}",
            weights.dims(),
            scales.dims()
        );
    }

    let num_tokens_i32 = to_i32(num_tokens, "num_tokens")?;
    let topk_i32 = to_i32(topk, "topk")?;
    let num_experts_i32 = to_i32(num_experts, "num_experts")?;
    let n_i32 = to_i32(n, "N")?;
    let k_i32 = to_i32(k, "K")?;
    let scale_row_stride = to_i32(scale_cols, "scale row stride")?;
    let block_size_y = to_i32(block_y, "block_size_y")?;
    let block_size_x = to_i32(block_x, "block_size_x")?;
    let out_numel = num_tokens * topk * n;

    let input_l = input.layout();
    let weights_l = weights.layout();
    let scales_l = scales.layout();
    let indices_l = indices.layout();

    let input_storage = input.storage_and_layout().0;
    let weights_storage = weights.storage_and_layout().0;
    let scales_storage = scales.storage_and_layout().0;
    let indices_storage = indices.storage_and_layout().0;

    let weights_s = match &*weights_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<F8E4M3>()?,
        _ => candle_core::bail!("Expected CUDA storage for weights"),
    };
    let scales_s = match &*scales_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
        _ => candle_core::bail!("Expected CUDA storage for scales"),
    };
    let indices_s = match &*indices_storage {
        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<u32>()?,
        _ => candle_core::bail!("Expected CUDA storage for indices"),
    };

    let (weights_ptr, _weights_guard) = slice_ptr(weights_s, weights_l.start_offset());
    let (scales_ptr, _scales_guard) = slice_ptr(scales_s, scales_l.start_offset());
    let (indices_ptr, _indices_guard) = slice_ptr(indices_s, indices_l.start_offset());

    match input.dtype() {
        DType::F16 => {
            let output = dev.alloc_zeros::<f16>(out_numel)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                // SAFETY: all pointers come from live CUDA slices whose guards are held for
                // this scope; shapes were validated above (expert ids are the caller's duty).
                unsafe {
                    ffi::launch_fp8_indexed_moe_gemm_f16(
                        input_ptr as *const _,
                        weights_ptr as *const _,
                        scales_ptr as *const _,
                        indices_ptr as *const _,
                        output_ptr as *mut _,
                        num_tokens_i32,
                        topk_i32,
                        num_experts_i32,
                        n_i32,
                        k_i32,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        input_has_topk_dim,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[num_tokens, topk, n]),
            )))
        }
        DType::BF16 => {
            let output = dev.alloc_zeros::<bf16>(out_numel)?;

            let input_s = match &*input_storage {
                Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                _ => candle_core::bail!("Expected CUDA storage for input"),
            };

            {
                let (output_ptr, _output_guard) = slice_ptr(&output, 0);
                let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());

                // SAFETY: all pointers come from live CUDA slices whose guards are held for
                // this scope; shapes were validated above (expert ids are the caller's duty).
                unsafe {
                    ffi::launch_fp8_indexed_moe_gemm_bf16(
                        input_ptr as *const _,
                        weights_ptr as *const _,
                        scales_ptr as *const _,
                        indices_ptr as *const _,
                        output_ptr as *mut _,
                        num_tokens_i32,
                        topk_i32,
                        num_experts_i32,
                        n_i32,
                        k_i32,
                        scale_row_stride,
                        block_size_y,
                        block_size_x,
                        input_has_topk_dim,
                        dev.cuda_stream().cu_stream(),
                    )
                };
            }

            let output_storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
            Ok(Tensor::from((
                Storage::Cuda(output_storage),
                candle_core::Shape::from_dims(&[num_tokens, topk, n]),
            )))
        }
        other => candle_core::bail!(
            "Unsupported input dtype for FP8 indexed MoE GEMM: {:?}",
            other
        ),
    }
}
1087
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    /// Dequantizes a 5x5 all-ones FP8 weight with 2x2 blocks against a 3x3 scale grid
    /// holding `0..9`, producing `out_ty` on `dev`.
    ///
    /// The FP8 weight is always created on the CPU and then copied to `dev`.
    fn dequant_ones_5x5(dev: &Device, out_ty: DType) -> Result<Tensor> {
        let weight = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?.to_device(dev)?;
        let inv_scales = Tensor::arange(0f32, 9f32, dev)?.reshape((3, 3))?;
        ops::fp8_blockwise_dequantize(&weight, &inv_scales, vec![2, 2], out_ty)
    }

    /// Expected output of `dequant_ones_5x5`: element `(r, c)` equals the scale of its
    /// 2x2 block, i.e. `3 * (r / 2) + c / 2`.
    fn expected_5x5() -> Vec<Vec<f32>> {
        (0..5)
            .map(|r| (0..5).map(|c| (3 * (r / 2) + c / 2) as f32).collect())
            .collect()
    }

    /// `expected_5x5` converted element-wise to `bf16`.
    fn expected_5x5_bf16() -> Vec<Vec<bf16>> {
        expected_5x5()
            .into_iter()
            .map(|row| row.into_iter().map(bf16::from_f32).collect())
            .collect()
    }

    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let res = dequant_ones_5x5(&Device::Cpu, DType::F32)?.to_vec2::<f32>()?;
        assert_eq!(res, expected_5x5());
        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = dequant_ones_5x5(&Device::Cpu, DType::F32)?.to_vec2::<f32>()?;
        let test = dequant_ones_5x5(&Device::new_cuda(0)?, DType::F32)?.to_vec2::<f32>()?;

        assert_eq!(test, truth);
        assert_eq!(test, expected_5x5());
        Ok(())
    }

    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let res = dequant_ones_5x5(&Device::Cpu, DType::BF16)?.to_vec2::<bf16>()?;
        assert_eq!(res, expected_5x5_bf16());
        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = dequant_ones_5x5(&Device::Cpu, DType::BF16)?.to_vec2::<bf16>()?;
        let test = dequant_ones_5x5(&Device::new_cuda(0)?, DType::BF16)?.to_vec2::<bf16>()?;

        assert_eq!(test, truth);
        assert_eq!(test, expected_5x5_bf16());
        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;
        let input = Tensor::randn(0f32, 2f32, (8, 8), dev)?;
        let block = vec![4, 4];

        let (quantized, scales) = ops::fp8_blockwise_quantize(&input, block.clone())?;

        // Quantization keeps the shape; an 8x8 input with 4x4 blocks has a 2x2 scale grid.
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims2()?, (2, 2));

        let dequantized = ops::fp8_blockwise_dequantize(&quantized, &scales, block, input.dtype())?;
        assert_eq!(dequantized.shape(), input.shape());

        // FP8 E4M3 is lossy, so only bound the worst-case absolute error.
        let before = input.to_vec2::<f32>()?;
        let after = dequantized.to_vec2::<f32>()?;
        let max_error = before
            .iter()
            .flatten()
            .zip(after.iter().flatten())
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max);

        assert!(max_error < 0.16, "Max error {} is too large", max_error);
        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        // Fixture: an FP8 weight plus its scales, downloaded from the Hugging Face Hub.
        let hub = ApiBuilder::new().with_progress(true).build().unwrap();
        let repo = hub.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));
        let filename = repo.get("test_fp8.safetensors").unwrap();
        // SAFETY: assumes the cached fixture file is not modified while it is mapped.
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        // The weight's input dimension is 2048.
        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        // Reference path: dequantize to BF16, then apply a plain linear layer.
        let weight_dq =
            ops::fp8_blockwise_dequantize(&weight, &scale, vec![128, 128], DType::BF16)?;
        let truth = Linear::new(weight_dq, None).forward(&xs)?;

        // TODO: compare against a real blockwise FP8 GEMM once it is exercised here.
        assert_eq!((32, 7168), truth.dims2()?);
        Ok(())
    }
}