mistralrs_quant/blockwise_fp8/ops.rs

use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

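/// Blockwise FP8 dequantization op: each `weight_block_size[0] x
/// weight_block_size[1]` tile of the FP8 (E4M3) weight shares a single f32
/// inverse scale, and the output is produced in `out_ty`. For example (see the
/// tests below), a 5x5 weight with 2x2 blocks pairs with a 3x3 scale grid.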
struct Fp8BlockwiseDequantize {
    weight_block_size: Vec<usize>,
    out_ty: DType,
}

impl Fp8BlockwiseDequantize {
    fn dispatch_dequant_blockwise<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        // One grid cell per scale entry; div_ceil covers partial edge blocks.
        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);

        let res = vec![T::zero(); weight.len()];

        (0..grid_y).into_par_iter().for_each(|y| {
            (0..grid_x).into_par_iter().for_each(|x| {
                let res_ptr = res.as_ptr() as *mut T;

                let scale = scale[y * scale_l.stride()[0] + x];

                let start_y = y * self.weight_block_size[0];
                let end_y = start_y + self.weight_block_size[0];

                let start_x = x * self.weight_block_size[1];
                let end_x = start_x + self.weight_block_size[1];

                for weight_y in start_y..end_y {
                    if weight_y >= weight_l.dims()[0] {
                        break;
                    }

                    let row_offset = weight_y * weight_l.stride()[0];
                    for weight_x in start_x..end_x {
                        if weight_x >= weight_l.dims()[1] {
                            break;
                        }

                        let weight_pos = row_offset + weight_x;

                        // SAFETY: Each (y, x) grid cell covers a disjoint block of
                        // `res`, so no two threads ever write the same index.
                        unsafe {
                            *res_ptr.wrapping_add(weight_pos) =
                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
                        }
                    }
                }
            });
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8BlockwiseDequantize {
    fn name(&self) -> &'static str {
        "fp8-blockwise-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2 and weight_block_size to have length 2");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(
                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
                ),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{
            backend::BackendStorage,
            cuda::{cudarc::driver::DevicePtr, WrapErr},
            CudaStorage,
        };
        use half::{bf16, f16};

        use crate::blockwise_fp8::ffi;

        if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
            candle_core::bail!("Blockwise FP8 dequant kernels are not available.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2 and weight_block_size to have length 2");
        }

        let dev = weight_s.device();

        let weight = weight_s
            .as_cuda_slice::<F8E4M3>()?
            .slice(weight_l.start_offset()..);
        let scale = scale_s
            .as_cuda_slice::<f32>()?
            .slice(scale_l.start_offset()..);

        let weight_height = weight_l.dim(0)? as i32;
        let weight_block_size_x = self.weight_block_size[0] as i32;
        let weight_width = weight_l.dim(1)? as i32;
        let weight_block_size_y = self.weight_block_size[1] as i32;
        let scale_stride = scale_l.stride()[0] as i32;
        let weight_row_stride = weight_l.stride()[0] as i32;

        let res = match self.out_ty {
            DType::F32 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<f32>(weight_l.shape().elem_count())
                    .w()?;
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_f32(
                        (*weight.device_ptr()) as *const _,
                        (*scale.device_ptr()) as *const _,
                        (*output.device_ptr()) as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        *dev.cu_stream(),
                    )
                };
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            DType::F16 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<f16>(weight_l.shape().elem_count())
                    .w()?;
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_f16(
                        (*weight.device_ptr()) as *const _,
                        (*scale.device_ptr()) as *const _,
                        (*output.device_ptr()) as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        *dev.cu_stream(),
                    )
                };
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            DType::BF16 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<bf16>(weight_l.shape().elem_count())
                    .w()?;
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_bf16(
                        (*weight.device_ptr()) as *const _,
                        (*scale.device_ptr()) as *const _,
                        (*output.device_ptr()) as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        *dev.cu_stream(),
                    )
                };
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }
}

/// FP8 blockwise dequantize.
/// - Expects `weight` to be FP8 (E4M3)
/// - Expects `inv_scales` to be f32
/// - weight * inv_scale = dequantized
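///
/// A minimal usage sketch (hedged; it mirrors the CPU test below and elides
/// error handling, which is why the doctest is `ignore`d):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// let dev = Device::Cpu;
/// // A 5x5 FP8 weight with 2x2 blocks needs a ceil(5/2) x ceil(5/2) = 3x3 scale grid.
/// let weight = Tensor::ones((5, 5), DType::F8E4M3, &dev)?;
/// let inv_scales = Tensor::arange(0f32, 9f32, &dev)?.reshape((3, 3))?;
/// let dequant = fp8_blockwise_dequantize(&weight, &inv_scales, vec![2, 2], DType::F32)?;
/// assert_eq!(dequant.dims2()?, (5, 5));
/// ```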
pub fn fp8_blockwise_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    weight_block_size: Vec<usize>,
    out_ty: DType,
) -> Result<Tensor> {
    // `inv_scales` is the op2 receiver, so it arrives as the first storage
    // (`scale_s`) in `cpu_fwd`/`cuda_fwd` above.
    inv_scales.apply_op2_no_bwd(
        weight,
        &Fp8BlockwiseDequantize {
            weight_block_size,
            out_ty,
        },
    )
}

#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

        let res = dequant.to_vec2::<f32>()?;
        assert_eq!(
            res,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }
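
    #[test]
    fn test_fp8_blockwise_dequant_f16() -> Result<()> {
        // A hedged sketch, not part of the original suite: it mirrors the f32
        // test above for the F16 output path that cpu_fwd also supports.
        // Values 0..=8 are exactly representable in f16, so exact equality holds.
        use half::f16;

        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F16)?;

        let res = dequant.to_vec2::<f16>()?;
        for (y, row) in res.iter().enumerate() {
            for (x, v) in row.iter().enumerate() {
                // Each element equals its block's scale, i.e. entry (y/2, x/2)
                // of the 3x3 arange grid.
                assert_eq!(v.to_f32(), ((y / 2) * 3 + (x / 2)) as f32);
            }
        }

        Ok(())
    }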

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::BF16)?;

        let res = dequant.to_vec2::<bf16>()?;
        assert_eq!(
            res,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let api = ApiBuilder::new().with_progress(true).build().unwrap();
        let api = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let filename = api.get("test_fp8.safetensors").unwrap();
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

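        // 128x128 blocks over a (7168, 2048) weight give a ceil(7168/128) x
        // ceil(2048/128) = 56x16 scale grid, which the next asserts check.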
        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        // in dim is 2048.
        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        let truth = {
            let weight_dq =
                ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

            let lin_dq = Linear::new(weight_dq, None);
            lin_dq.forward(&xs)?
        };

        // TODO: will be adding real blockwise fp8 gemm shortly ;)
        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}