// mistralrs_quant/blockwise_fp8/ops.rs

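//! Blockwise FP8 dequantization, implemented as a candle `CustomOp2` with CPU,
//! CUDA, and Metal backends. Weights are stored as F8E4M3 with one f32 inverse
//! scale per block; dequantization computes `weight * inv_scale` element-wise.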
use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

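/// Dequantizes F8E4M3 weights using one f32 inverse scale per
/// `weight_block_size` block, producing a tensor of dtype `out_ty`.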
struct Fp8BlockwiseDequantize {
    weight_block_size: Vec<usize>,
    out_ty: DType,
}

impl Fp8BlockwiseDequantize {
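    /// CPU path: walk the scale grid in parallel and write each dequantized
    /// block into a shared output buffer of element type `T`.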
    fn dispatch_dequant_blockwise<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);

        let res = vec![T::zero(); weight.len()];

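        // Each (y, x) pair covers one scale block. Blocks are disjoint, so the
        // nested parallel iterators never write to the same output element.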
        (0..grid_y).into_par_iter().for_each(|y| {
            (0..grid_x).into_par_iter().for_each(|x| {
                let res_ptr = res.as_ptr() as *mut T;

                let scale = scale[y * scale_l.stride()[0] + x];

                let start_y = y * self.weight_block_size[0];
                let end_y = start_y + self.weight_block_size[0];

                let start_x = x * self.weight_block_size[1];
                let end_x = start_x + self.weight_block_size[1];

                for weight_y in start_y..end_y {
                    if weight_y >= weight_l.dims()[0] {
                        break;
                    }

                    let row_offset = weight_y * weight_l.stride()[0];
                    for weight_x in start_x..end_x {
                        if weight_x >= weight_l.dims()[1] {
                            break;
                        }

                        let weight_pos = row_offset + weight_x;

                        // SAFETY: We know each thread will only update independent values!
                        unsafe {
                            *res_ptr.wrapping_add(weight_pos) =
                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
                        }
                    }
                }
            });
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8BlockwiseDequantize {
    fn name(&self) -> &'static str {
        "fp8-blockwise-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2 and weight_block_size to have length 2");
        }

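        // Dispatch on the requested output dtype; the output shape matches the weight.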
        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(
                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
                ),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{
            backend::BackendStorage,
            cuda::{cudarc::driver::DevicePtr, WrapErr},
            CudaStorage,
        };
        use half::{bf16, f16};

        use crate::blockwise_fp8::ffi;

        if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
            candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2 and weight_block_size to have length 2");
        }

        let dev = weight_s.device();

        let weight = weight_s
            .as_cuda_slice::<F8E4M3>()?
            .slice(weight_l.start_offset()..);
        let scale = scale_s
            .as_cuda_slice::<f32>()?
            .slice(scale_l.start_offset()..);

        let weight_height = weight_l.dim(0)? as i32;
        let weight_block_size_x = self.weight_block_size[0] as i32;
        let weight_width = weight_l.dim(1)? as i32;
        let weight_block_size_y = self.weight_block_size[1] as i32;
        let scale_stride = scale_l.stride()[0] as i32;
        let weight_row_stride = weight_l.stride()[0] as i32;

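        // Allocate a zeroed output of the requested dtype, then launch the
        // dtype-specific FFI kernel; only the element type differs per arm.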
        let res = match self.out_ty {
            DType::F32 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<f32>(weight_l.shape().elem_count())
                    .w()?;
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_f32(
                        (*weight.device_ptr()) as *const _,
                        (*scale.device_ptr()) as *const _,
                        (*output.device_ptr()) as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        *dev.cu_stream(),
                    )
                };
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            DType::F16 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<f16>(weight_l.shape().elem_count())
                    .w()?;
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_f16(
                        (*weight.device_ptr()) as *const _,
                        (*scale.device_ptr()) as *const _,
                        (*output.device_ptr()) as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        *dev.cu_stream(),
                    )
                };
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            DType::BF16 => {
                let output = weight_s
                    .device()
                    .alloc_zeros::<bf16>(weight_l.shape().elem_count())
                    .w()?;
                unsafe {
                    ffi::launch_dequant_fp8_blockwise_kernel_bf16(
                        (*weight.device_ptr()) as *const _,
                        (*scale.device_ptr()) as *const _,
                        (*output.device_ptr()) as *mut _,
                        weight_height,
                        weight_width,
                        weight_row_stride,
                        scale_stride,
                        weight_block_size_y,
                        weight_block_size_x,
                        *dev.cu_stream(),
                    )
                };
                CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        scale_s: &candle_core::MetalStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::MetalStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;

        if weight_l.start_offset() != 0
            || !weight_l.is_contiguous()
            || weight_s.dtype() != DType::F8E4M3
        {
            candle_core::bail!("Expected F8E4M3 weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() || scale_s.dtype() != DType::F32
        {
            candle_core::bail!("Expected F32 scales to have start offset 0, contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle_core::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle_core::bail!("Expected scale to be rank 2 and weight_block_size to have length 2");
        }

        let command_buffer = weight_s.device().command_buffer()?;
        command_buffer.set_label("dequant-blockwise-fp8");

        let device = weight_s.device();

        let out_shape = weight_l.shape().clone();

        let output = device.new_buffer(
            out_shape.elem_count(),
            self.out_ty,
            "dequant-blockwise-fp8",
        )?;

        let weight_height = weight_l.dim(0)? as u32;
        let weight_block_size_x = self.weight_block_size[0] as u32;
        let weight_width = weight_l.dim(1)? as u32;
        let weight_block_size_y = self.weight_block_size[1] as u32;
        let scale_stride = scale_l.stride()[0] as u32;
        let weight_row_stride = weight_l.stride()[0] as u32;

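        // Encode the dequant kernel onto the command buffer. `output` was
        // allocated above to hold `out_ty` elements for the full weight shape.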
        crate::metal_kernels::call_dequant_blockwise_fp8(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            self.out_ty,
            weight_s.buffer(),
            scale_s.buffer(),
            &output,
            weight_height,
            weight_width,
            weight_row_stride,
            scale_stride,
            weight_block_size_y,
            weight_block_size_x,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            self.out_ty,
        );
        Ok((newstorage, out_shape))
    }
}

/// FP8 blockwise dequantize.
/// - Expects `weight` to be F8E4M3.
/// - Expects `inv_scales` to be F32, with one scale per `weight_block_size` block.
/// - Dequantizes as `weight * inv_scale`, cast to `out_ty`.
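///
/// A minimal usage sketch, mirroring the CPU test at the bottom of this file
/// (marked `ignore` since it assumes a candle `Result` context for `?`):
/// ```ignore
/// // A 5x5 FP8 weight with 2x2 blocks needs a ceil(5/2) x ceil(5/2) = 3x3 scale grid.
/// let weight = Tensor::ones((5, 5), DType::F8E4M3, &Device::Cpu)?;
/// let inv_scales = Tensor::arange(0f32, 9f32, &Device::Cpu)?.reshape((3, 3))?;
/// let dq = fp8_blockwise_dequantize(&weight, &inv_scales, vec![2, 2], DType::F32)?;
/// ```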
pub fn fp8_blockwise_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    weight_block_size: Vec<usize>,
    out_ty: DType,
) -> Result<Tensor> {
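    // Note the operand order: scales are the first op argument, so the backend
    // `*_fwd` implementations receive (scale, weight).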
    inv_scales.apply_op2_no_bwd(
        weight,
        &Fp8BlockwiseDequantize {
            weight_block_size,
            out_ty,
        },
    )
}

#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

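        // With an all-ones FP8 weight, every output element equals the scale of
        // its 2x2 block (edge blocks are truncated to fit the 5x5 weight).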
        let res = dequant.to_vec2::<f32>()?;
        assert_eq!(
            res,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::BF16)?;

        let res = dequant.to_vec2::<bf16>()?;
        assert_eq!(
            res,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let api = ApiBuilder::new().with_progress(true).build().unwrap();
        let api = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let filename = api.get("test_fp8.safetensors").unwrap();
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        // The input dim is 2048.
        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

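        // Reference path: dequantize the full weight, then apply a plain Linear.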
        let truth = {
            let weight_dq =
                ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

            let lin_dq = Linear::new(weight_dq, None);
            lin_dq.forward(&xs)?
        };

        // TODO: will be adding real blockwise fp8 gemm shortly ;)
        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}