mistralrs_quant/fp8/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use quantize::QuantizationResult;

mod quantize;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    utils::{
        deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
        UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

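/// FP8 linear layer with scalar (per-tensor) scales.
///
/// `lin` holds the FP8-quantized weight (plus any bias); the three scale fields
/// are scalar tensors that feed the cuBLASlt FP8 matmul in `forward`. Under the
/// usual per-tensor convention (an assumption about `quantize`, which lives in
/// the sibling module and is not shown here), `quant_scale` maps full precision
/// into the FP8 range and the dequant scales are its reciprocal, i.e.
/// `w ≈ w_fp8 * dequant_w_scale`.
///
/// Illustrative construction via the `QuantMethod` trait (`some_linear` is a
/// placeholder):
/// ```ignore
/// let fp8 = FP8Linear::new(QuantMethodConfig::FP8 {
///     lin: some_linear,
///     dtype: DType::F8E4M3,
/// })?;
/// ```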
#[derive(Debug)]
pub struct FP8Linear {
    lin: Linear,
    dequant_w_scale: Tensor,
    dequant_x_scale: Tensor,
    quant_scale: Tensor,
    /// Quantized type
    dtype: DType,
}

impl QuantMethod for FP8Linear {
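    /// Construct by quantizing the weight of an existing `Linear` to the given
    /// FP8 dtype. Only `QuantMethodConfig::FP8` is accepted here; every other
    /// configuration variant is unreachable for this quant method.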
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::FP8 { lin, dtype } => {
                let QuantizationResult {
                    qw,
                    quantize_scale,
                    dequantize_scale,
                } = Self::quantize(lin.weight(), dtype)?;
                Ok(Self {
                    lin: Linear::new(qw, lin.bias().cloned()),
                    dequant_x_scale: dequantize_scale.clone(), // This is probably wrong!
                    dequant_w_scale: dequantize_scale,
                    quant_scale: quantize_scale,
                    dtype,
                })
            }
        }
    }
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        Ok(self.dequantize(DType::F32)?.weight().clone())
    }

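    /// Matrix-multiply `x` by the FP8 weight.
    ///
    /// When a cuBLASlt handle is available for `x`'s device, this dispatches to
    /// the FP8 batched matmul, quantizing `x` on the fly if it is not already
    /// `F8E4M3`; `x` must then have at least 3 dimensions. Without a handle, the
    /// weight is dequantized to `x`'s dtype and a plain `Linear` forward is used.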
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(x.device().clone());

        match CUBLASLT_CONTROLLER.get_for_device(x.device()) {
            Some(handle) => {
                let n_dims = x.dims().len();
                if n_dims < 3 {
                    candle_core::bail!(
                        "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions"
                    );
                }
                // Set up target shape
                let mut tgt_shape = x.dims().to_vec();
                *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?;

                // Flatten for correct dims
                let mut x = x.flatten_to(D::Minus(3))?;

                // Prepare the b tensor. If it is not quantized, quantize it
                let mut dequant_x_scale = self.dequant_x_scale.clone();
                if !matches!(x.dtype(), DType::F8E4M3) {
                    let QuantizationResult {
                        qw,
                        quantize_scale: _,
                        dequantize_scale,
                    } = Self::quantize(&x, DType::F8E4M3)?;
                    x = qw;
                    dequant_x_scale = dequantize_scale;
                }

                // Handle bias
                let beta = match self.lin.bias().is_some() {
                    true => Some(1.0),
                    false => None,
                };

                // Naming
                let a = self.lin.weight().unsqueeze(0)?;
                let b = x;

                handle
                    .batch_matmul_f8(
                        &a,
                        &b,
                        &self.dequant_w_scale,
                        &dequant_x_scale,
                        &self.quant_scale,
                        self.lin.bias(),
                        None,
                        beta,
                        None,
                        None,
                    )?
                    .reshape(tgt_shape)
            }
            None => {
                // Dequantize matmul
                let dequant_x = x.clone();
                let lin = self.dequantize(x.dtype())?;
                lin.forward(&dequant_x)
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize(delta.dtype())?;
        let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned());
        Ok(Arc::new(Self::new(QuantMethodConfig::FP8 {
            lin: new,
            dtype: self.dtype,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.lin.weight().device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (3 for fp8), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scalar, f32, little endian
// -----------------------
// Dequant X scalar, f32, little endian
// -----------------------
// Quant scalar, f32, little endian
// -----------------------
// Quantization type, u32, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
//
// A small round-trip sketch of the fixed header fields lives in the test module
// at the bottom of this file.

impl QuantizedSerde for FP8Linear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "fp8-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.lin.bias().cloned())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for fp8 is 3
        buffer.push(QuantizedSerdeType::Fp8 as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, self.lin.weight())?;

        // Dequant W scale
        buffer.extend(self.dequant_w_scale.to_scalar::<f32>()?.to_le_bytes());
        // Dequant X scale
        buffer.extend(self.dequant_x_scale.to_scalar::<f32>()?.to_le_bytes());
        // Quant scale
        buffer.extend(self.quant_scale.to_scalar::<f32>()?.to_le_bytes());

        // DType
        write_dtype(self.dtype, &mut buffer);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Acquire the load guard before materializing tensors on the device.
        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        // DType
        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            lin: Linear::new(w, b),
            dequant_w_scale,
            dequant_x_scale,
            quant_scale,
            dtype,
        }))
    }
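    /// Like `deserialize`, but returns any serialized bias separately instead of
    /// attaching it to the inner `Linear`; the caller decides how to apply it.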
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        // DType
        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                lin: Linear::new(w, None),
                dequant_w_scale,
                dequant_x_scale,
                quant_scale,
                dtype,
            }),
            b,
        ))
    }
}
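
// A minimal round-trip sketch of the fixed UQFF header described above
// (version, ISQ type tag, bias flag). It mirrors the first writes in
// `serialize_with_bias` and the first reads in `deserialize`; the module and
// test names are illustrative only.
#[cfg(test)]
mod fp8_uqff_header_sketch {
    use super::*;
    use byteorder::{LittleEndian, ReadBytesExt};
    use std::io::Cursor;

    #[test]
    fn header_roundtrip() -> candle_core::Result<()> {
        // Write the header the same way `serialize_with_bias` does.
        let mut buffer = Vec::new();
        buffer.extend(&UQFF_VERSION.to_le_bytes());
        buffer.push(QuantizedSerdeType::Fp8 as u8);
        buffer.push(true as u8); // "has bias" flag

        // Read it back the same way `deserialize` does.
        let mut cursor = Cursor::new(buffer);
        let version = cursor.read_u32::<LittleEndian>()?;
        assert!(version_is_compatible(version).is_ok());
        assert_eq!(cursor.read_u8()? as usize, QuantizedSerdeType::Fp8 as usize);
        assert!(cursor.read_u8()? != 0);
        Ok(())
    }
}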