mistralrs_quant/fp8/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use quantize::QuantizationResult;

mod quantize;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, F8MatmulOutType, CUBLASLT_HANDLE},
    utils::{
        deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
        UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

#[derive(Debug)]
pub struct FP8Linear {
    lin: Linear,
    /// Per-tensor scale used to dequantize the FP8 weight.
    dequant_w_scale: Tensor,
    /// Per-tensor scale used to dequantize the FP8 activation (input).
    dequant_x_scale: Tensor,
    /// Per-tensor scale applied when quantizing to FP8.
    quant_scale: Tensor,
    /// Quantized type
    dtype: DType,
}

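// The three `Tensor` scales above are per-tensor scalars. As a rough illustration of how such
// scale pairs are commonly derived for FP8 E4M3 (whose largest finite value is 448.0), the
// sketch below computes a quantize/dequantize scale pair from an absmax. This is only an
// assumption about the scheme implemented in `quantize.rs` (not shown here); the constant,
// function name, and signature are illustrative and not part of this crate's API.
#[allow(dead_code)]
fn example_per_tensor_fp8_scales(values: &[f32]) -> (f32, f32) {
    const F8E4M3_MAX: f32 = 448.0;
    // Largest absolute value in the tensor (floored at a tiny positive number to avoid
    // dividing by zero) determines how the range maps onto [-448, 448].
    let absmax = values.iter().fold(f32::MIN_POSITIVE, |m, v| m.max(v.abs()));
    // Multiply by `quantize_scale` before casting to FP8.
    let quantize_scale = F8E4M3_MAX / absmax;
    // Multiply the FP8 values by `dequantize_scale` to recover the original magnitudes.
    let dequantize_scale = absmax / F8E4M3_MAX;
    (quantize_scale, dequantize_scale)
}
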
impl QuantMethod for FP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::FP8 { lin, dtype } => {
                let QuantizationResult {
                    qw,
                    quantize_scale,
                    dequantize_scale,
                } = Self::quantize(lin.weight(), dtype)?;
                Ok(Self {
                    lin: Linear::new(qw, lin.bias().cloned()),
                    // This is probably wrong; `forward` recomputes the activation scale
                    // whenever `x` is not already FP8.
                    dequant_x_scale: dequantize_scale.clone(),
                    dequant_w_scale: dequantize_scale,
                    quant_scale: quantize_scale,
                    dtype,
                })
            }
        }
    }
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        Ok(self.dequantize(DType::F32)?.weight().clone())
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Batch matrix multiplication
        maybe_init_cublas_lt_wrapper(x.device().clone());

        match *CUBLASLT_HANDLE.lock().unwrap() {
            Some(handle) => {
                let n_dims = x.dims().len();
                if n_dims < 3 {
                    candle_core::bail!(
                        "FP8Linear `matmul` via cuBLASLt expects `x` to have at least 3 dimensions"
                    );
                }
                // Set up target shape
                let mut tgt_shape = x.dims().to_vec();
                *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?;

                // Flatten for correct dims
                let mut x = x.flatten_to(D::Minus(3))?;

                // Prepare the b tensor. If it is not quantized, quantize it
                let mut dequant_x_scale = self.dequant_x_scale.clone();
                if !matches!(x.dtype(), DType::F8E4M3) {
                    let QuantizationResult {
                        qw,
                        quantize_scale: _,
                        dequantize_scale,
                    } = Self::quantize(&x, DType::F8E4M3)?;
                    x = qw;
                    dequant_x_scale = dequantize_scale;
                }

                // Handle bias
                let beta = match self.lin.bias().is_some() {
                    true => Some(1.0),
                    false => None,
                };

                // Match the cuBLASLt naming: `a` is the weight, `b` is the (quantized) input.
                let a = self.lin.weight().unsqueeze(0)?;
                let b = x;

                handle
                    .batch_matmul_f8(
                        &a,
                        &b,
                        &self.dequant_w_scale,
                        &dequant_x_scale,
                        &self.quant_scale,
                        self.lin.bias(),
                        None,
                        beta,
                        None,
                        None,
                        F8MatmulOutType::BF16, // Output in bf16 to avoid manual dequant
                    )?
                    .reshape(tgt_shape)
            }
            None => {
                // No cuBLASLt handle available: dequantize the weight and do a plain matmul.
                let dequant_x = x.clone();
                let lin = self.dequantize(x.dtype())?;
                lin.forward(&dequant_x)
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize(delta.dtype())?;
        let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned());
        Ok(Arc::new(Self::new(QuantMethodConfig::FP8 {
            lin: new,
            dtype: self.dtype,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.lin.weight().device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (3 for fp8), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scale, f32, little endian
// -----------------------
// Dequant X scale, f32, little endian
// -----------------------
// Quant scale, f32, little endian
// -----------------------
// Quantization type, u32, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

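// Illustrative only: a minimal, standalone sketch of reading the fixed header fields described
// above (UQFF version, ISQ type, bias flag) from a serialized buffer. It reuses the same
// `byteorder` calls as `deserialize` below, but the function name is hypothetical and the
// sketch is not part of this crate's API; the real readers are `deserialize` and
// `deserialize_ext_bias`.
#[allow(dead_code)]
fn example_read_uqff_header(data: &[u8]) -> Result<(u32, u8, bool)> {
    let mut buffer = Cursor::new(data.to_vec());
    // UQFF version, u32, little endian.
    let version = buffer.read_u32::<LittleEndian>()?;
    // ISQ type, u8 (3 for fp8).
    let isq_type = buffer.read_u8()?;
    // Bias flag, u8 boolean.
    let has_bias = buffer.read_u8()? != 0;
    Ok((version, isq_type, has_bias))
}
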
impl QuantizedSerde for FP8Linear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "fp8-linear"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.lin.bias().cloned())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for fp8 is 3
        buffer.push(QuantizedSerdeType::Fp8 as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, self.lin.weight())?;

        // Dequant W scale
        buffer.extend(self.dequant_w_scale.to_scalar::<f32>()?.to_le_bytes());
        // Dequant X scale
        buffer.extend(self.dequant_x_scale.to_scalar::<f32>()?.to_le_bytes());
        // Quant scale
        buffer.extend(self.quant_scale.to_scalar::<f32>()?.to_le_bytes());

        // DType
        write_dtype(self.dtype, &mut buffer);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        // Acquire the load guard before moving any tensor data onto the device,
        // matching `deserialize_ext_bias` below.
        let _acquired_load_guard = guard.acquire();

        let w = deserialize_tensor(&mut buffer, device)?;

        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        // DType
        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            lin: Linear::new(w, b),
            dequant_w_scale,
            dequant_x_scale,
            quant_scale,
            dtype,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w = deserialize_tensor(&mut buffer, device)?;

        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        // DType
        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                lin: Linear::new(w, None),
                dequant_w_scale,
                dequant_x_scale,
                quant_scale,
                dtype,
            }),
            b,
        ))
    }
}