// mistralrs_quant/fp8/quantize.rs

1use candle_core::{DType, Result, Tensor};
2use candle_nn::Linear;
3use float8::F8E4M3;
4use half::bf16;
5
6use super::FP8Linear;
7
pub(super) struct QuantizationResult {
    /// Quantized tensor (f8)
    pub(super) qw: Tensor,
    /// Scalar, f32 tensor.
    ///
    /// Convert unquantized to quantized tensor as follows:
    /// `q = x * qs`
    pub(super) quantize_scale: Tensor,
    /// Scalar, f32 tensor. Reciprocal of `quantize_scale`.
    ///
    /// Convert quantized to unquantized tensor as follows:
    /// `x = q * dqs`
    pub(super) dequantize_scale: Tensor,
}
22
impl FP8Linear {
    /// Quantize `data` into the FP8 dtype `dtype`.
    ///
    /// The input is cast to BF16, the global (all-axes) absolute maximum is
    /// computed, and a scalar scale `F8E4M3::MAX / absmax` is derived so that
    /// the largest magnitude in `data` maps onto the FP8 representable
    /// maximum. The scaled tensor is then cast to `dtype`.
    ///
    /// Returns the quantized weights plus the scalar quantize/dequantize
    /// scales (the latter is the reciprocal of the former).
    pub(super) fn quantize(data: &Tensor, dtype: DType) -> Result<QuantizationResult> {
        let data = data.to_dtype(DType::BF16)?;
        // Reduce over every remaining dimension until absmax is a scalar
        // (rank-0) tensor holding the global absolute maximum.
        let mut absmax = data.abs()?;
        while !absmax.dims().is_empty() {
            absmax = absmax.max(0)?;
        }

        // scale = F8E4M3::MAX / absmax, clamped to [F8E4M3::MIN, F8E4M3::MAX].
        // NOTE(review): if absmax is 0 the division yields +inf, which the
        // clamp pins to F8E4M3::MAX — confirm this is the intended behavior
        // for an all-zero input tensor.
        let max_v = F8E4M3::MAX.to_f64();
        let scale = (max_v / absmax)?
            .clamp(F8E4M3::MIN.to_f32(), F8E4M3::MAX.to_f32())?
            .to_dtype(DType::F32)?;
        // Apply the scale in the data's own dtype (BF16) before the FP8 cast.
        let to_cast = data.broadcast_mul(&scale.to_dtype(data.dtype())?)?;
        let qw = if data.device().is_metal() {
            // Evil hack to allow metal shader to get the double value!
            // Each BF16 element's f64 bit pattern is smuggled through an i64
            // tensor so the Metal cast kernel can reinterpret it when
            // converting to `dtype`. This round-trips through host memory.
            let transmute_data = to_cast
                .flatten_all()?
                .to_vec1::<bf16>()?
                .into_iter()
                .map(|x| x.to_f64_const().to_bits() as i64)
                .collect::<Vec<_>>();
            Tensor::from_vec(transmute_data, data.shape(), data.device())?.to_dtype(dtype)?
        } else {
            // Non-Metal devices can cast directly.
            to_cast.to_dtype(dtype)?
        };
        Ok(QuantizationResult {
            qw,
            quantize_scale: scale.clone(),
            // Reciprocal scale: multiply the quantized values by this to
            // recover the original magnitudes.
            dequantize_scale: scale.recip()?,
        })
    }

    /// Dequantize the stored FP8 weight back to `dtype` by casting and then
    /// multiplying by the stored dequantize scale. The bias (if any) is
    /// passed through unchanged.
    pub(super) fn dequantize(&self, dtype: DType) -> Result<Linear> {
        let dequant_w = self
            .lin
            .weight()
            .to_dtype(dtype)?
            .broadcast_mul(&self.dequant_w_scale.to_dtype(dtype)?)?;
        Ok(Linear::new(dequant_w, self.lin.bias().cloned()))
    }
}
64
#[cfg(test)]
mod tests {
    use candle_core::{
        quantized::{GgmlDType, QTensor},
        DType, Device, Result, Tensor,
    };

    use crate::fp8::FP8Linear;

    use super::QuantizationResult;

    /// Round-trips random data through F8E4M3 quantization and prints the
    /// mean absolute error, alongside a GGML Q8_0 round-trip as a baseline.
    #[test]
    fn test_roundtrip_f8e4m3() -> Result<()> {
        // Select a device according to the enabled backend features.
        #[cfg(not(feature = "metal"))]
        let dev = Device::cuda_if_available(0)?;
        #[cfg(feature = "metal")]
        let dev = Device::new_metal(0)?;

        let data = Tensor::rand(0f32, 1f32, (32, 32), &dev)?;

        let QuantizationResult {
            qw,
            quantize_scale: _,
            dequantize_scale,
        } = FP8Linear::quantize(&data, DType::F8E4M3)?;

        // Recover an f32 approximation of the original and measure error.
        let dequant = qw.to_dtype(DType::F32)?.broadcast_mul(&dequantize_scale)?;
        let diff1 = (&data - dequant)?.abs()?.mean_all()?;
        println!("{diff1}");

        // Baseline: the same data through GGML's Q8_0 quantizer.
        let q8_0 = QTensor::quantize(&data, GgmlDType::Q8_0)?.dequantize(&dev)?;
        let diff2 = (&data - q8_0)?.abs()?.mean_all()?;
        println!("{diff2}");

        Ok(())
    }

    /// Exercises the cuBLASLt FP8 batched matmul path end-to-end: quantizes
    /// both operands to F8E4M3 and invokes `batch_matmul_f8`.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cublaslt_matmul() -> Result<()> {
        use crate::cublaslt::{maybe_init_cublas_lt_wrapper, F8MatmulOutType, CUBLASLT_HANDLE};
        let dev = Device::new_cuda(0)?;

        let w = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;
        let mut x = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;

        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(x.device().clone());

        let handle = CUBLASLT_HANDLE.lock().unwrap().unwrap();

        // Quantize the weight operand.
        let QuantizationResult {
            qw,
            quantize_scale: quant_scale,
            dequantize_scale: dequant_a_scale,
        } = FP8Linear::quantize(&w, DType::F8E4M3)?;

        // Quantize the activation operand too, if it is not FP8 already.
        let mut dequant_b_scale = dequant_a_scale.clone();
        if x.dtype() != DType::F8E4M3 {
            let QuantizationResult {
                qw: quantized_x,
                quantize_scale: _,
                dequantize_scale: x_scale,
            } = FP8Linear::quantize(&x, DType::F8E4M3)?;
            x = quantized_x;
            dequant_b_scale = x_scale;
        }

        // FP8 quantized matmul
        let _res = handle.batch_matmul_f8(
            &qw,
            &x,
            &dequant_a_scale,
            &dequant_b_scale,
            &quant_scale,
            None,
            None,
            None,
            None,
            None,
            F8MatmulOutType::BF16,
        )?;

        Ok(())
    }
}