mistralrs_quant/blockwise_fp8/mod.rs

use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};

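/// Linear layer whose weight is stored in blockwise FP8 (`F8E4M3`): the weight is split into
/// `weight_block_size[0] x weight_block_size[1]` blocks, each with one `f32` entry in
/// `weight_scale_inv`. The weight is dequantized to `dequant_dtype` before it is used.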
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
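    /// Dequantize the full weight matrix: apply the per-block scales from `weight_scale_inv`
    /// and cast the result to `dequant_dtype`.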
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

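    /// Forward pass: dequantize the full weight, then run the unquantized matmul path.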
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Always dequantize before the matmul.
        // TODO: add a dedicated blockwise FP8 matmul kernel?
        let weight = self.dequantize_w()?;
        // Dispatch to `UnquantLinear`: on CUDA it always goes through cuBLASLt (including the
        // bias), so it is the better path here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

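    /// In-situ quantization (ISQ): dequantize the blockwise FP8 weight to `dequant_dtype`, then
    /// re-quantize it into the requested target (HQQ, AFQ, GGUF k-quants, plain FP8, or an
    /// unquantized linear layer when `dtype` is `None`).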
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire();
                // Ignore imatrix altogether

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (3 for fp8), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scalar, f32, little endian
// -----------------------
// Dequant X scalar, f32, little endian
// -----------------------
// Quant scalar, f32, little endian
// -----------------------
// Quantization type, u32, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

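/// Load a blockwise FP8 linear layer from `vb`, or a `DummyLayer` if the expected tensors are
/// absent. Expects an `F8E4M3` `weight` of shape `(out_dim, in_dim)` and an `f32`
/// `weight_scale_inv` of shape `(ceil(out_dim / block[0]), ceil(in_dim / block[1]))`, where
/// `block` is `weight_block_size` from the quantization config.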
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Fp8 { weight_block_size } = config else {
        candle_core::bail!("Unexpected quantization config.")
    };

    // Handle the case where the layer is dummy (no tensors)
    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}
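
// A minimal sketch (not part of the upstream test suite) illustrating the scale-grid shape
// arithmetic used in `blockwise_fp8_linear_b`: with a hypothetical 7168 x 2048 weight and
// 128 x 128 blocks, `weight_scale_inv` holds one f32 per block, i.e. a 56 x 16 grid.
#[cfg(test)]
mod scale_grid_shape_sketch {
    #[test]
    fn scale_grid_shape_for_example_dims() {
        // Hypothetical dimensions chosen only for illustration.
        let (out_dim, in_dim) = (7168usize, 2048usize);
        let weight_block_size = [128usize, 128usize];
        // Same shape computation as the `weight_scale_inv` load above.
        let rows = out_dim.div_ceil(weight_block_size[0]);
        let cols = in_dim.div_ceil(weight_block_size[1]);
        assert_eq!((rows, cols), (56, 16));
    }
}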