mistralrs_quant/blockwise_fp8/
mod.rs

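//! Blockwise FP8 (F8E4M3) linear layers.
//!
//! Weights are stored in FP8 together with per-block scales in `weight_scale_inv`;
//! dequantization applies one scale per `weight_block_size[0] x weight_block_size[1]`
//! block to recover the weight in `dequant_dtype`. On CUDA, `forward` and
//! `gather_forward` use native blockwise FP8 GEMM kernels when available, and
//! otherwise fall back to dequantizing and running an unquantized matmul.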
use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;
pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};

#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

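    /// Forward pass.
    ///
    /// On CUDA, when the native blockwise FP8 GEMM kernels are available, the input is
    /// flattened to 2D and multiplied directly against the FP8 weight. Otherwise the
    /// weight is dequantized and the unquantized matmul path is used.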
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Try to use native FP8 GEMM kernel on CUDA
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                // Handle batched inputs by flattening to 2D
                let orig_dims = x.dims().to_vec();
                let x_2d = if orig_dims.len() > 2 {
                    // Flatten all but last dim: [batch, seq, features] -> [batch*seq, features]
                    let features = orig_dims[orig_dims.len() - 1];
                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                    x.reshape((batch_size, features))?
                } else {
                    x.clone()
                };

                // Use native FP8 GEMM kernel
                let result = ops::fp8_blockwise_matmul(
                    &x_2d,
                    &self.weight,
                    &self.weight_scale_inv,
                    &self.weight_block_size,
                )?;

                // Reshape back to original batch dimensions
                let result = if orig_dims.len() > 2 {
                    let out_features = result.dim(1)?;
                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                    new_dims.push(out_features);
                    result.reshape(new_dims)?
                } else {
                    result
                };

                // Apply bias if present
                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

        // Fallback: dequantize and use unquantized matmul
        let weight = self.dequantize_w()?;
        // Dispatch to the unquantized layer: on CUDA it always goes through cuBLASLt
        // (with fused bias), so it is the better path here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    /// Compute the matmul of `self` and `x`. `self` should contain the expert weights.
    ///
    /// If `x` is (n_tokens, 1, in_features) and the weights are
    /// (n_experts, out_features, in_features), then `indices` is (n_tokens, n_experts_per_tok).
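    ///
    /// A shape-only sketch of the call (hypothetical sizes; not a doctest):
    ///
    /// ```ignore
    /// // x:       (n_tokens, 1, in_features)
    /// // weights: (n_experts, out_features, in_features), held by `self`
    /// // indices: (n_tokens, n_experts_per_tok), dtype U32
    /// let out = layer.gather_forward(&x, &indices)?;
    /// // out:     (n_tokens, n_experts_per_tok, out_features)
    /// ```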
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Try to use native FP8 indexed MoE GEMM kernel on CUDA
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                // Use native FP8 indexed MoE GEMM kernel (expects U32 indices)
                let result = ops::fp8_indexed_moe_gemm(
                    x,
                    &self.weight,
                    &self.weight_scale_inv,
                    indices,
                    &self.weight_block_size,
                )?;
                // Apply bias if present (broadcast over tokens and topk)
                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

        // Fallback: dequantize weights and compute manually
        let weight = self.dequantize_w()?;

        // Expected shapes:
        // - x: (n_tokens, 1, hidden_dim) or (n_tokens, n_experts_per_tok, hidden_dim)
        // - indices: (n_tokens, n_experts_per_tok)
        // - weight: (n_experts, out_features, in_features)

        let (n_tokens, n_experts_per_tok) = indices.dims2()?;
        let (_n_experts, out_features, _in_features) = weight.dims3()?;

        // Flatten indices to select expert weights
        let flat_indices = indices.flatten_all()?;

        // Select weights for each (token, expert) pair
        // weight_selected: (n_tokens * n_experts_per_tok, out_features, in_features)
        let weight_selected = weight.index_select(&flat_indices, 0)?;

        // Reshape x for batched matmul
        let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
            // x is (n_tokens, 1, hidden_dim) - broadcast to (n_tokens * n_experts_per_tok, 1, hidden_dim)
            x.squeeze(1)?
                .unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
                .contiguous()?
        } else if x.dims().len() == 3 {
            // x is (n_tokens, n_experts_per_tok, hidden_dim)
            x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
        } else {
            // x is (n_tokens, hidden_dim)
            x.unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
                .contiguous()?
        };

        // Batched matmul: (batch, 1, k) @ (batch, k, n) = (batch, 1, n).
        // weight_selected is (batch, n, k), so transpose its last two dims first.
        let weight_t = weight_selected.transpose(1, 2)?;
        let result = x_expanded.matmul(&weight_t)?;

        // Reshape result to (n_tokens, n_experts_per_tok, out_features)
        let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;

        // Apply bias if present
        if let Some(ref bias) = self.bias {
            result.broadcast_add(bias)
        } else {
            Ok(result)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

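    /// Re-quantize this layer from its blockwise FP8 representation: the weight is first
    /// dequantized, then quantized to the requested [`IsqType`] (HQQ, AFQ, GGUF k-quants,
    /// or `FP8Linear`), or kept unquantized when `dtype` is `None`.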
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (3 for fp8), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scalar, f32, little endian
// -----------------------
// Dequant X scalar, f32, little endian
// -----------------------
// Quant scalar, f32, little endian
// -----------------------
// Quantization type, u32, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

/// Create a BlockwiseFP8Linear for MoE with 3D weights [num_experts, N, K].
/// This is used by FusedExperts to enable gather_forward with native FP8 GEMM.
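///
/// A minimal usage sketch (hypothetical tensors; the scale layout and the 128x128 block
/// size are assumptions, not taken from this crate; not a doctest):
///
/// ```ignore
/// // w: FP8 expert weights, shape [num_experts, N, K], dtype F8E4M3
/// // s: per-block scales for `w` (assumed shape [num_experts, N/128, K/128], dtype F32)
/// let experts = blockwise_fp8_moe(w, s, vec![128, 128], DType::BF16)?;
/// // x: (n_tokens, 1, K); indices: (n_tokens, n_experts_per_tok) as U32
/// let out = experts.gather_forward(&x, &indices)?;
/// ```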
pub fn blockwise_fp8_moe(
    weight: Tensor,
    weight_scale_inv: Tensor,
    weight_block_size: Vec<usize>,
    dequant_dtype: DType,
) -> Result<Arc<dyn QuantMethod>> {
    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_scale_inv,
        bias: None,
        dequant_dtype,
        weight_block_size,
    }))
}

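/// Load a blockwise FP8 linear layer from a `ShardedVarBuilder`.
///
/// Expects a `weight` tensor of shape `(out_dim, in_dim)` in `F8E4M3` and a
/// `weight_scale_inv` tensor of shape
/// `(out_dim.div_ceil(block[0]), in_dim.div_ceil(block[1]))` in `F32`.
/// If only `weight` is present the layer is loaded unquantized, and if neither
/// tensor is present a dummy layer is returned.
///
/// A minimal call sketch (hypothetical `cfg`, `shard`, and `vb`; not a doctest):
///
/// ```ignore
/// let lin = blockwise_fp8_linear_b(in_dim, out_dim, &cfg, /*bias=*/ true, shard, vb)?;
/// let y = lin.forward(&x)?;
/// ```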
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Fp8 { weight_block_size } = config else {
        candle_core::bail!("Unexpected quantization config.")
    };

    // Handle the case where the checkpoint actually stores an unquantized weight (no scales present)
    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
    }

    // Handle the case where the layer is dummy (no tensors)
    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}