mistralrs_quant/blockwise_fp8/mod.rs
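
//! Blockwise FP8 (E4M3) linear layers: weights are stored as FP8 values plus one f32
//! scale per block. On CUDA a native blockwise GEMM kernel is used when available;
//! otherwise the weight is dequantized and an unquantized matmul is dispatched.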

use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;
pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};

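/// A linear layer whose weight is stored in blockwise FP8 form: `weight` holds the FP8
/// (E4M3) values, `weight_scale_inv` holds one f32 scale per `weight_block_size` block,
/// and `dequant_dtype` is the dtype used when the weight is dequantized.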
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
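        // Only the `BlockwiseFP8` config variant is valid for this layer; any other
        // variant indicates a caller bug, hence the `unreachable!()` below.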
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Try to use native FP8 GEMM kernel on CUDA
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                // Handle batched inputs by flattening to 2D
                let orig_dims = x.dims().to_vec();
                let x_2d = if orig_dims.len() > 2 {
                    // Flatten all but last dim: [batch, seq, features] -> [batch*seq, features]
                    let features = orig_dims[orig_dims.len() - 1];
                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                    x.reshape((batch_size, features))?
                } else {
                    x.clone()
                };

                // Use native FP8 GEMM kernel
                let result = ops::fp8_blockwise_matmul(
                    &x_2d,
                    &self.weight,
                    &self.weight_scale_inv,
                    &self.weight_block_size,
                )?;

                // Reshape back to original batch dimensions
                let result = if orig_dims.len() > 2 {
                    let out_features = result.dim(1)?;
                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                    new_dims.push(out_features);
                    result.reshape(new_dims)?
                } else {
                    result
                };

                // Apply bias if present
                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

        // Fallback: dequantize and use unquantized matmul
        let weight = self.dequantize_w()?;
        // Dispatch to the unquantized path, which uses cuBLASLt (with fused bias) on
        // CUDA, so it is the better option here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    /// Compute the matmul of `self` and `x`, where `self` holds the weights.
    ///
    /// If `x` is (n_tokens, 1, cols) and the weights are (n_experts, rows, cols),
    /// then `indices` is (n_tokens, n_experts_per_tok).
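    ///
    /// Shape sketch with hypothetical sizes: weights `(8, 11008, 4096)`, `x` `(4, 1, 4096)`,
    /// and `indices` `(4, 2)` yield an output of shape `(4, 2, 11008)`.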
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Try to use native FP8 indexed MoE GEMM kernel on CUDA
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                // Use native FP8 indexed MoE GEMM kernel (expects U32 indices)
                let result = ops::fp8_indexed_moe_gemm(
                    x,
                    &self.weight,
                    &self.weight_scale_inv,
                    indices,
                    &self.weight_block_size,
                )?;
                // Apply bias if present (broadcast over tokens and topk)
                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

        // Fallback: dequantize weights and compute manually
        let weight = self.dequantize_w()?;

        // Expected shapes:
        // - x: (n_tokens, 1, hidden_dim) or (n_tokens, n_experts_per_tok, hidden_dim)
        // - indices: (n_tokens, n_experts_per_tok)
        // - weight: (n_experts, out_features, in_features)

        let (n_tokens, n_experts_per_tok) = indices.dims2()?;
        let (_n_experts, out_features, _in_features) = weight.dims3()?;

        // Flatten indices to select expert weights
        let flat_indices = indices.flatten_all()?;

        // Select weights for each (token, expert) pair
        // weight_selected: (n_tokens * n_experts_per_tok, out_features, in_features)
        let weight_selected = weight.index_select(&flat_indices, 0)?;

        // Reshape x for batched matmul
        let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
            // x is (n_tokens, 1, hidden_dim) - broadcast to (n_tokens * n_experts_per_tok, 1, hidden_dim)
            x.squeeze(1)?
                .unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
                .contiguous()?
        } else if x.dims().len() == 3 {
            // x is (n_tokens, n_experts_per_tok, hidden_dim)
            x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
        } else {
            // x is (n_tokens, hidden_dim)
            x.unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
                .contiguous()?
        };

        // Batched matmul: (batch, 1, k) @ (batch, k, n).T = (batch, 1, n)
        // weight_selected is (batch, n, k), so we need to transpose last two dims
        let weight_t = weight_selected.transpose(1, 2)?;
        let result = x_expanded.matmul(&weight_t)?;

        // Reshape result to (n_tokens, n_experts_per_tok, out_features)
        let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;

        // Apply bias if present
        if let Some(ref bias) = self.bias {
            result.broadcast_add(bias)
        } else {
            Ok(result)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
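        // Requantization path: fully dequantize the blockwise FP8 weight first, then
        // quantize it into the requested ISQ format on `device`.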
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (3 for fp8), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scalar, f32, little endian
// -----------------------
// Dequant X scalar, f32, little endian
// -----------------------
// Quant scalar, f32, little endian
// -----------------------
// Quantization type, u32, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

/// Create a BlockwiseFP8Linear for MoE with 3D weights [num_experts, N, K].
/// This is used by FusedExperts to enable gather_forward with native FP8 GEMM.
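///
/// A minimal usage sketch; `expert_w`, `expert_s`, `hidden_states`, and `topk_ids` are
/// hypothetical tensors (FP8 expert weights, their per-block scales, activations, and
/// routing indices) assumed to be loaded elsewhere:
///
/// ```ignore
/// // 128x128 weight blocks are a typical layout; dequantize to BF16 on the fallback path.
/// let experts = blockwise_fp8_moe(expert_w, expert_s, vec![128, 128], DType::BF16)?;
/// let out = experts.gather_forward(&hidden_states, &topk_ids)?;
/// ```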
pub fn blockwise_fp8_moe(
    weight: Tensor,
    weight_scale_inv: Tensor,
    weight_block_size: Vec<usize>,
    dequant_dtype: DType,
) -> Result<Arc<dyn QuantMethod>> {
    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_scale_inv,
        bias: None,
        dequant_dtype,
        weight_block_size,
    }))
}

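/// Build a blockwise FP8 linear layer of logical shape `(out_dim, in_dim)` from `vb`,
/// falling back to an unquantized layer when only `weight` is present and to a dummy
/// layer when `weight` is missing.
///
/// A hypothetical call-site sketch; `cfg`, `shard`, `vb`, and `hidden` are assumed to come
/// from the caller (the config must be `QuantizedConfig::Fp8` with a block size set, and
/// `vb.pp("gate_proj")` assumes the usual prefix-based lookup):
///
/// ```ignore
/// let proj = blockwise_fp8_linear_b(4096, 11008, &cfg, false, shard, vb.pp("gate_proj"))?;
/// let y = proj.forward(&hidden)?;
/// ```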
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Fp8 { weight_block_size } = config else {
        candle_core::bail!("Unexpected quantization config.")
    };

    // Handle the case where we actually have an unquantized layer
    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
    }

    // Handle the case where the layer is dummy (no tensors)
    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    // Blockwise FP8 requires weight_block_size to be set
    let Some(weight_block_size) = weight_block_size else {
        candle_core::bail!("Blockwise FP8 requires weight_block_size to be set. Use per-tensor FP8 for models without block sizes.")
    };
    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}