mistralrs_quant/blockwise_fp8/mod.rs

use std::{
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits, HqqConfig, HqqLayer, IsqType, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, Shard,
    ShardedVarBuilder, UnquantLinear,
};
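/// Linear layer whose weight is stored in blockwise FP8 (`F8E4M3`) form alongside a per-block
/// scale tensor (`weight_scale_inv`). Each `weight_block_size[0] x weight_block_size[1]` block
/// of the weight shares one `F32` scale. Conceptually, dequantization is
/// `w_deq[i][j] = f32(weight[i][j]) * weight_scale_inv[i / bs0][j / bs1]`, cast to
/// `dequant_dtype`; see `ops::fp8_blockwise_dequantize` for the actual kernel.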
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
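    // Reconstruct the weight in `dequant_dtype` by applying the per-block scales.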
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Always dequantize before the matmul.
        // TODO: add a dedicated blockwise FP8 matmul kernel?
        let weight = self.dequantize_w()?;
        // Dispatch to the unquantized path: on CUDA it always goes through cuBLASLt
        // (with fused bias), so it is the better choice here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }
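    // Requantize in-situ: fully dequantize the blockwise FP8 weight, then quantize it to the
    // requested ISQ target (HQQ, GGUF k-quants, per-tensor FP8), or fall back to an
    // unquantized linear when `dtype` is `None`.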
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire();
                // Ignore imatrix altogether

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
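    // No CPU thread cap is imposed for any ISQ target when requantizing this layer.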
    fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option<NonZeroUsize> {
        match dtype {
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1
            | IsqType::HQQ4
            | IsqType::HQQ8 => None,
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (3 for fp8), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scalar, f32, little endian
// -----------------------
// Dequant X scalar, f32, little endian
// -----------------------
// Quant scalar, f32, little endian
// -----------------------
// Quantization type, u32, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}
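/// Build a blockwise FP8 linear layer from a `ShardedVarBuilder`.
///
/// Expects `weight` with shape `(out_dim, in_dim)` in `F8E4M3` and `weight_scale_inv` with one
/// `F32` scale per `weight_block_size[0] x weight_block_size[1]` block; if either tensor is
/// missing, a `DummyLayer` is returned instead. The dequantization dtype is taken from
/// `vb.dtype()`.
///
/// A minimal usage sketch (not a doctest); `cfg`, `shard`, `vb`, and `x` are assumed to be
/// constructed by the caller:
///
/// ```ignore
/// let linear = blockwise_fp8_linear_b(in_dim, out_dim, &cfg, /* bias */ false, shard, vb)?;
/// let y = linear.forward(&x)?;
/// ```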
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    // Handle the case where the layer is dummy (no tensors)
    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let weight_block_size = config
        .weight_block_size
        .as_ref()
        .expect("Blockwise FP8 requires weight_block_size in config");
    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
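    // One F32 scale per `weight_block_size[0] x weight_block_size[1]` block; div_ceil covers
    // dimensions that are not exact multiples of the block size.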
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}