mistralrs_quant/fp8/quantize.rs

use candle_core::{DType, Result, Tensor};
use candle_nn::Linear;
use float8::F8E4M3;
use half::bf16;

use super::FP8Linear;

pub(super) struct QuantizationResult {
    /// Quantized tensor (f8).
    pub(super) qw: Tensor,
    /// Scalar, f32 tensor.
    ///
    /// Convert an unquantized tensor to the quantized form: `q = x * quantize_scale`.
    pub(super) quantize_scale: Tensor,
    /// Scalar, f32 tensor. Reciprocal of `quantize_scale`.
    ///
    /// Convert the quantized tensor back: `x = q * dequantize_scale`.
    pub(super) dequantize_scale: Tensor,
}

impl FP8Linear {
    pub(super) fn quantize(data: &Tensor, dtype: DType) -> Result<QuantizationResult> {
        let data = data.to_dtype(DType::BF16)?;
        // Per-tensor absmax: reduce over dim 0 until a scalar remains.
        let mut absmax = data.abs()?;
        while !absmax.dims().is_empty() {
            absmax = absmax.max(0)?;
        }

        // scale = F8E4M3::MAX / absmax, clamped so a zero absmax cannot
        // produce an infinite scale.
        let max_v = F8E4M3::MAX.to_f64();
        let scale = (max_v / absmax)?
            .clamp(F8E4M3::MIN.to_f32(), F8E4M3::MAX.to_f32())?
            .to_dtype(DType::F32)?;
        let to_cast = data.broadcast_mul(&scale.to_dtype(data.dtype())?)?;
        let qw = if data.device().is_metal() {
            // Metal has no cast kernel to f8e4m3, so do the conversion
            // element-wise on the host instead.
            let transmute_data = to_cast
                .flatten_all()?
                .to_vec1::<bf16>()?
                .into_iter()
                .map(|x| F8E4M3::from_f32(x.to_f32()))
                .collect::<Vec<_>>();
            Tensor::from_vec(transmute_data, data.shape(), data.device())?
        } else {
            to_cast.to_dtype(dtype)?
        };
        Ok(QuantizationResult {
            qw,
            quantize_scale: scale.clone(),
            dequantize_scale: scale.recip()?,
        })
    }
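
    // Illustrative arithmetic (an added sketch, not from the original code),
    // for a tensor whose absmax is 2.0, given that F8E4M3::MAX == 448:
    //   quantize_scale   = 448.0 / 2.0 = 224.0
    //   qw               = f8e4m3(x * 224.0)
    //   dequantize_scale = 1.0 / 224.0
    //   x ≈ qw * dequantize_scale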

    pub(super) fn dequantize(&self, dtype: DType) -> Result<Linear> {
        // Cast the f8 weight up to `dtype` and undo the quantization scale.
        let dequant_w = self
            .lin
            .weight()
            .to_dtype(dtype)?
            .broadcast_mul(&self.dequant_w_scale.to_dtype(dtype)?)?;
        Ok(Linear::new(dequant_w, self.lin.bias().cloned()))
    }
}
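
// Hypothetical usage sketch (not part of this file): recover a plain
// `candle_nn::Linear` for a non-FP8 fallback path, assuming `fp8` is an
// `FP8Linear` and `x` is an activation tensor:
//
//     let linear = fp8.dequantize(DType::BF16)?;
//     let y = linear.forward(&x)?; // requires `candle_nn::Module` in scope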

#[cfg(test)]
mod tests {
    use candle_core::{
        quantized::{GgmlDType, QTensor},
        DType, Device, Result, Tensor,
    };

    use crate::fp8::FP8Linear;

    use super::QuantizationResult;

    #[test]
    fn test_roundtrip_f8e4m3() -> Result<()> {
        #[cfg(not(feature = "metal"))]
        let dev = Device::cuda_if_available(0)?;
        #[cfg(feature = "metal")]
        let dev = Device::new_metal(0)?;

        let data = Tensor::rand(0f32, 1f32, (32, 32), &dev)?;

        let QuantizationResult {
            qw,
            quantize_scale: _,
            dequantize_scale,
        } = FP8Linear::quantize(&data, DType::F8E4M3)?;

        let dequant = qw.to_dtype(DType::F32)?.broadcast_mul(&dequantize_scale)?;

        let diff1 = (&data - dequant)?.abs()?.mean_all()?;

        println!("{diff1}");

        // Compare against the GGML Q8_0 round-trip error as a reference point.
        let q8_0 = QTensor::quantize(&data, GgmlDType::Q8_0)?.dequantize(&dev)?;
        let diff2 = (&data - q8_0)?.abs()?.mean_all()?;

        println!("{diff2}");
        Ok(())
    }
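
    // An added sketch (not in the original tests): by construction the two
    // scales returned by `quantize` are reciprocals, so their product should
    // be ~1.0.
    #[test]
    fn test_scales_are_reciprocal() -> Result<()> {
        let dev = Device::Cpu;
        let data = Tensor::rand(0f32, 1f32, (8, 8), &dev)?;
        let QuantizationResult {
            qw: _,
            quantize_scale,
            dequantize_scale,
        } = FP8Linear::quantize(&data, DType::F8E4M3)?;
        let prod = (quantize_scale * dequantize_scale)?.to_scalar::<f32>()?;
        assert!((prod - 1.0).abs() < 1e-5);
        Ok(())
    }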

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cublaslt_matmul() -> Result<()> {
        use crate::cublaslt::{maybe_init_cublas_lt_wrapper, F8MatmulOutType, CUBLASLT_HANDLE};
        let dev = Device::new_cuda(0)?;

        let w = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;
        let mut x = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;

        // Initialize the global cuBLASlt wrapper for this device.
        maybe_init_cublas_lt_wrapper(x.device().clone());

        let handle = CUBLASLT_HANDLE.lock().unwrap().unwrap();

        // Quantize the "weight" side.
        let QuantizationResult {
            qw,
            quantize_scale: quant_scale,
            dequantize_scale: dequant_a_scale,
        } = FP8Linear::quantize(&w, DType::F8E4M3)?;

        // Quantize the activation side too if it is not already f8.
        let mut dequant_b_scale = dequant_a_scale.clone();
        if !matches!(x.dtype(), DType::F8E4M3) {
            let QuantizationResult {
                qw,
                quantize_scale: _,
                dequantize_scale,
            } = FP8Linear::quantize(&x, DType::F8E4M3)?;
            x = qw;
            dequant_b_scale = dequantize_scale;
        }

        let a = qw;
        let b = x;

        // FP8 batched matmul with per-tensor scales, producing BF16 output.
        let _res = handle.batch_matmul_f8(
            &a,
            &b,
            &dequant_a_scale,
            &dequant_b_scale,
            &quant_scale,
            None,
            None,
            None,
            None,
            None,
            F8MatmulOutType::BF16,
        )?;

        Ok(())
    }
}