mistralrs_quant/fp8/quantize.rs

use candle_core::{DType, Result, Tensor};
use candle_nn::Linear;
use float8::F8E4M3;

use super::FP8Linear;

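/// Result of quantizing a tensor to FP8. Both scales are scalar f32 tensors and
/// reciprocals of each other: `q = x * quantize_scale`, `x ≈ q * dequantize_scale`.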
pub(super) struct QuantizationResult {
    pub(super) qw: Tensor,
    pub(super) quantize_scale: Tensor,
    pub(super) dequantize_scale: Tensor,
}

impl FP8Linear {
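    /// Quantize `data` to `dtype` using a single per-tensor scale derived from the
    /// global absolute maximum, so that the largest magnitude scales to `F8E4M3::MAX`.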
    pub(super) fn quantize(data: &Tensor, dtype: DType) -> Result<QuantizationResult> {
        let data = data.to_dtype(DType::BF16)?;
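        // Reduce dim 0 repeatedly until only a scalar remains: the global absmax.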
        let mut absmax = data.abs()?;
        while !absmax.dims().is_empty() {
            absmax = absmax.max(0)?;
        }

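        // Per-tensor scale mapping the observed absmax onto the E4M3 range,
        // clamped to E4M3's finite bounds.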
        let max_v = F8E4M3::MAX.to_f64();
        let scale = (max_v / absmax)?
            .clamp(F8E4M3::MIN.to_f32(), F8E4M3::MAX.to_f32())?
            .to_dtype(DType::F32)?;
        let to_cast = data.broadcast_mul(&scale.to_dtype(data.dtype())?)?;
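        // F8E4M3 goes through the dedicated scalar conversion op; other target
        // dtypes use a plain cast.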
        let qw = if dtype == DType::F8E4M3 {
            crate::scalar_fp8::ops::dtype_to_fp8(&to_cast)?
        } else {
            to_cast.to_dtype(dtype)?
        };
        Ok(QuantizationResult {
            qw,
            quantize_scale: scale.clone(),
            dequantize_scale: scale.recip()?,
        })
    }

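    /// Reconstruct a dense `Linear` in `dtype` by rescaling the stored quantized
    /// weights with `dequant_w_scale`; the bias is passed through unchanged.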
    pub(super) fn dequantize(&self, dtype: DType) -> Result<Linear> {
        let dequant_w = self
            .lin
            .weight()
            .to_dtype(dtype)?
            .broadcast_mul(&self.dequant_w_scale.to_dtype(dtype)?)?;
        Ok(Linear::new(dequant_w, self.lin.bias().cloned()))
    }
}

#[cfg(test)]
mod tests {
    #[cfg(not(feature = "metal"))]
    use candle_core::{
        quantized::{GgmlDType, QTensor},
        DType, Device, Result, Tensor,
    };

    #[cfg(not(feature = "metal"))]
    use crate::fp8::FP8Linear;

    #[cfg(not(feature = "metal"))]
    use super::QuantizationResult;

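    /// Round-trip a random matrix through FP8 quantization and print the mean
    /// absolute error alongside a GGML Q8_0 round-trip of the same data.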
    #[test]
    #[cfg(not(feature = "metal"))]
    fn test_roundtrip_f8e4m3() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let data = Tensor::rand(0f32, 1f32, (32, 32), &dev)?;

        let QuantizationResult {
            qw,
            quantize_scale: _,
            dequantize_scale,
        } = FP8Linear::quantize(&data, DType::F8E4M3)?;

        let dequant = crate::scalar_fp8::ops::fp8_to_dtype(&qw, DType::F32)?
            .broadcast_mul(&dequantize_scale)?;

        let diff1 = (&data - dequant)?.abs()?.mean_all()?;

        println!("{diff1}");

        let q8_0 = QTensor::quantize(&data, GgmlDType::Q8_0)?.dequantize(&dev)?;
        let diff2 = (&data - q8_0)?.abs()?.mean_all()?;

        println!("{diff2}");
        Ok(())
    }

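    /// Quantize both operands to FP8 and drive a batched matmul through the
    /// cuBLASLt wrapper, passing each operand's dequantize scale.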
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cublaslt_matmul() -> Result<()> {
        use crate::cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER};
        let dev = Device::new_cuda(0)?;

        let w = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;
        let mut x = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?;

        maybe_init_cublas_lt_wrapper(x.device().clone());

        let handle = CUBLASLT_CONTROLLER.get().unwrap();

        let QuantizationResult {
            qw,
            quantize_scale: quant_scale,
            dequantize_scale: dequant_a_scale,
        } = FP8Linear::quantize(&w, DType::F8E4M3)?;

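        // Quantize the activations too if they are not already FP8.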
        let mut dequant_b_scale = dequant_a_scale.clone();
        if !matches!(x.dtype(), DType::F8E4M3) {
            let QuantizationResult {
                qw,
                quantize_scale: _,
                dequantize_scale,
            } = FP8Linear::quantize(&x, DType::F8E4M3)?;
            x = qw;
            dequant_b_scale = dequantize_scale;
        }

        let a = qw;
        let b = x;

        let _res = handle.batch_matmul_f8(
            &a,
            &b,
            &dequant_a_scale,
            &dequant_b_scale,
            &quant_scale,
            None,
            None,
            None,
            None,
            None,
        )?;

        Ok(())
    }
}