mistralrs_quant/fp8/quantize.rs

use candle_core::{DType, Result, Tensor};
use candle_nn::Linear;
use float8::F8E4M3;

use super::FP8Linear;

pub(super) struct QuantizationResult {
    /// Quantized tensor (f8)
    pub(super) qw: Tensor,
    /// Scalar, f32 tensor.
    ///
    /// Convert an unquantized tensor to a quantized tensor as follows:
    /// `q = x * qs`
    pub(super) quantize_scale: Tensor,
    /// Scalar, f32 tensor. Reciprocal of `quantize_scale`.
    ///
    /// Convert a quantized tensor back to an unquantized tensor as follows:
    /// `x = q * dqs`
    pub(super) dequantize_scale: Tensor,
}

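// Illustrative sketch of the scale algebra documented on `QuantizationResult`,
// in plain f64 arithmetic with a hypothetical absmax of 7.0. Real f8 storage
// adds rounding error on top of this otherwise exact round trip.
#[cfg(test)]
#[test]
fn scale_algebra_sketch() {
    let absmax = 7.0_f64; // hypothetical absolute maximum of a weight tensor
    let qs = F8E4M3::MAX.to_f64() / absmax; // quantize_scale: 448 / 7 = 64
    let dqs = qs.recip(); // dequantize_scale, the reciprocal
    let x = 3.5_f64;
    let q = x * qs; // q = x * qs
    assert!((q * dqs - x).abs() < 1e-12); // x = q * dqs
}
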
impl FP8Linear {
    pub(super) fn quantize(data: &Tensor, dtype: DType) -> Result<QuantizationResult> {
        let data = data.to_dtype(DType::BF16)?;
        // Reduce across every dimension to find the global absolute maximum.
        let mut absmax = data.abs()?;
        while !absmax.dims().is_empty() {
            absmax = absmax.max(0)?;
        }

        // Scale so that absmax maps onto the largest representable f8e4m3 value.
        let max_v = F8E4M3::MAX.to_f64();
        let scale = (max_v / absmax)?
            .clamp(F8E4M3::MIN.to_f32(), F8E4M3::MAX.to_f32())?
            .to_dtype(DType::F32)?;
        let to_cast = data.broadcast_mul(&scale.to_dtype(data.dtype())?)?;
        let qw = if dtype == DType::F8E4M3 {
            crate::scalar_fp8::ops::dtype_to_fp8(&to_cast)?
        } else {
            to_cast.to_dtype(dtype)?
        };
        Ok(QuantizationResult {
            qw,
            quantize_scale: scale.clone(),
            dequantize_scale: scale.recip()?,
        })
    }

    pub(super) fn dequantize(&self, dtype: DType) -> Result<Linear> {
        // Upcast the stored f8 weight and multiply by the dequantize scale
        // to recover an ordinary `candle_nn::Linear`.
        let dequant_w = self
            .lin
            .weight()
            .to_dtype(dtype)?
            .broadcast_mul(&self.dequant_w_scale.to_dtype(dtype)?)?;
        Ok(Linear::new(dequant_w, self.lin.bias().cloned()))
    }
}

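// Usage sketch (illustrative only; `fp8` and `x` are hypothetical, standing in
// for an already-constructed `FP8Linear` and an input tensor): dequantizing
// yields a plain `candle_nn::Linear`, usable via `Module::forward`.
//
//     let linear = fp8.dequantize(DType::BF16)?;
//     let y = linear.forward(&x)?;
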
#[cfg(test)]
mod tests {
    #[cfg(not(feature = "metal"))]
    use candle_core::{
        quantized::{GgmlDType, QTensor},
        DType, Device, Result, Tensor,
    };

    #[cfg(not(feature = "metal"))]
    use crate::fp8::FP8Linear;

    #[cfg(not(feature = "metal"))]
    use super::QuantizationResult;

    #[test]
    #[cfg(not(feature = "metal"))]
    fn test_roundtrip_f8e4m3() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let data = Tensor::rand(0f32, 1f32, (32, 32), &dev)?;

        let QuantizationResult {
            qw,
            quantize_scale: _,
            dequantize_scale,
        } = FP8Linear::quantize(&data, DType::F8E4M3)?;

        let dequant = crate::scalar_fp8::ops::fp8_to_dtype(&qw, DType::F32)?
            .broadcast_mul(&dequantize_scale)?;

        let diff1 = (&data - dequant)?.abs()?.mean_all()?;

        println!("{diff1}");

        // Compare against GGML Q8_0 quantization error on the same data.
        let q8_0 = QTensor::quantize(&data, GgmlDType::Q8_0)?.dequantize(&dev)?;
        let diff2 = (&data - q8_0)?.abs()?.mean_all()?;

        println!("{diff2}");
        Ok(())
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cublaslt_matmul() -> Result<()> {
        use crate::cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER};
        let dev = Device::new_cuda(0)?;

        let w = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;
        let mut x = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;

        // Set up the cuBLASLt wrapper for this device.
        maybe_init_cublas_lt_wrapper(x.device().clone());

        let handle = CUBLASLT_CONTROLLER.get().unwrap();

        let QuantizationResult {
            qw,
            quantize_scale: quant_scale,
            dequantize_scale: dequant_a_scale,
        } = FP8Linear::quantize(&w, DType::F8E4M3)?;

        // Quantize the activations too if they are not already f8.
        let mut dequant_b_scale = dequant_a_scale.clone();
        if !matches!(x.dtype(), DType::F8E4M3) {
            let QuantizationResult {
                qw,
                quantize_scale: _,
                dequantize_scale,
            } = FP8Linear::quantize(&x, DType::F8E4M3)?;
            x = qw;
            dequant_b_scale = dequantize_scale;
        }

        let a = qw;
        let b = x;

        // FP8 quantized batch matmul via cuBLASLt.
        let _res = handle.batch_matmul_f8(
            &a,
            &b,
            &dequant_a_scale,
            &dequant_b_scale,
            &quant_scale,
            None,
            None,
            None,
            None,
            None,
        )?;

        Ok(())
    }
}