mistralrs_quant/mxfp4/mod.rs

use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};

use crate::{
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
    ShardedVarBuilder,
};

#[cfg(feature = "cuda")]
pub(crate) mod ffi;
#[cfg(feature = "metal")]
pub(crate) mod metal_ops;
#[cfg(feature = "cuda")]
pub(crate) mod ops;

/// MXFP4 block size (32 elements per scale)
pub const MXFP4_BLOCK_SIZE: usize = 32;

pub(crate) const N_BITS: usize = 4;
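
// Storage footprint, as a quick worked example (illustrative numbers, not taken from any
// particular checkpoint): a row of K = 4096 features packs into K / 2 = 2048 bytes of FP4
// nibbles plus K / MXFP4_BLOCK_SIZE = 128 E8M0 scale bytes, i.e. roughly
// N_BITS + 8 / 32 = 4.25 bits per weight before any bias.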

#[derive(Debug)]
pub struct MXFP4Layer {
    /// Packed FP4 weights: [N, K/2] or [num_experts, N, K/2]
    /// Each byte contains 2 FP4 values (low nibble = k, high nibble = k+1)
    #[allow(dead_code)]
    blocks: Tensor,
    /// E8M0 scales: [N, K/32] or [num_experts, N, K/32]
    /// Each byte is an 8-bit exponent with bias 127
    scales: Tensor,
    /// Optional bias: [N] or [num_experts, N]
    #[allow(dead_code)]
    bias: Option<Tensor>,
}
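
// Decode sketch for a single element `k` of one output row (this mirrors the CPU path in
// `dequantize_weights` below; the byte values are made up for illustration):
//   nibble = if k % 2 == 0 { blocks[k / 2] & 0x0F } else { blocks[k / 2] >> 4 };
//   value  = FP4_LUT[nibble] * 2^(scales[k / 32] as i32 - 127);
// e.g. blocks byte 0x21 with scale byte 128 decodes to (0.5, 1.0) * 2 = (1.0, 2.0).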

impl QuantMethod for MXFP4Layer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::MXFP4 {
                blocks,
                scales,
                bias,
            } => Ok(Self {
                blocks,
                scales,
                bias,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        #[cfg(feature = "metal")]
        if self.blocks.device().is_metal() {
            use crate::afq::ops;
            use crate::{AfqBits, AfqGroupSize};
            return ops::afq_dequantize_op(
                &self.blocks,
                &self.scales,
                &self.scales.clone(),
                AfqGroupSize::Low,
                AfqBits::Mxfp4,
            );
        }
        // CPU fallback
        self.dequantize_weights()
    }

    #[allow(unused_variables)]
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
            let orig_dims = x.dims().to_vec();
            let x_2d = if orig_dims.len() > 2 {
                let features = orig_dims[orig_dims.len() - 1];
                let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                x.reshape((batch_size, features))?
            } else {
                x.clone()
            };

            let result = ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;

            if orig_dims.len() > 2 {
                let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                new_dims.push(result.dim(1)?);
                return result.reshape(new_dims);
            }
            return Ok(result);
        }

        #[cfg(feature = "metal")]
        {
            if x.device().is_metal() {
                let orig_dims = x.dims().to_vec();
                let x_2d = if orig_dims.len() > 2 {
                    let features = orig_dims[orig_dims.len() - 1];
                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                    x.reshape((batch_size, features))?
                } else {
                    x.clone()
                };

                let result =
                    metal_ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;

                if orig_dims.len() > 2 {
                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                    new_dims.push(result.dim(1)?);
                    return result.reshape(new_dims);
                }
                return Ok(result);
            }
        }

        self.forward_dequantize(x)
    }

    #[allow(unused_variables)]
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
            return ops::mxfp4_indexed_moe_gemm(
                x,
                &self.blocks,
                &self.scales,
                self.bias.as_ref(),
                indices,
            );
        }

        #[cfg(feature = "metal")]
        {
            if x.device().is_metal() {
                return metal_ops::mxfp4_indexed_moe_gemm(
                    x,
                    &self.blocks,
                    &self.scales,
                    self.bias.as_ref(),
                    indices,
                );
            }
        }

        self.gather_forward_dequantize(x, indices)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("MXFP4Layer does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::BF16, self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("MXFP4Layer does not support ISQ")
    }
}

impl MXFP4Layer {
    /// Check if the device supports MXFP4 operations
    fn device_supported(_device: &Device) -> bool {
        #[cfg(feature = "cuda")]
        if matches!(_device, Device::Cuda(_)) {
            return ffi::HAVE_MXFP4_GEMM_KERNELS;
        }
        #[cfg(feature = "metal")]
        if _device.is_metal() {
            return true;
        }
        false
    }

    pub fn linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !Self::device_supported(vb.device()) {
            candle_core::bail!("MXFP4Layer requires CUDA or Metal device.");
        }

        let QuantizedConfig::MXFP4 {} = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let blocks = vb.get_with_hints_dtype(
            (out_dim, in_dim / 2),
            "blocks",
            Default::default(),
            DType::U8,
        )?;
        let scales = vb.get_with_hints_dtype(
            (out_dim, in_dim / MXFP4_BLOCK_SIZE),
            "scales",
            Default::default(),
            DType::U8,
        )?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }

    pub fn packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !Self::device_supported(vb.device()) {
            candle_core::bail!("MXFP4Layer requires CUDA or Metal device.");
        }

        let QuantizedConfig::MXFP4 {} = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let blocks = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim / 2),
            "blocks",
            Default::default(),
            DType::U8,
        )?;
        let scales = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim / MXFP4_BLOCK_SIZE),
            "scales",
            Default::default(),
            DType::U8,
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }

    /// Load GPT-OSS style MXFP4 experts (combined gate_up_proj format).
    ///
    /// GPT-OSS stores tensors as:
    /// - `{name}_blocks`: [num_experts, out_dim, num_blocks, 16] where 16 bytes = 32 FP4 values
    /// - `{name}_scales`: [num_experts, out_dim, num_blocks]
    /// - `{name}_bias`: [num_experts, out_dim]
    ///
    /// This function loads and reshapes the 4D blocks tensor to 3D [num_experts, out_dim, in_dim/2].
    pub fn packed_gptoss_linear(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        bias: bool,
        name: &str,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !Self::device_supported(vb.device()) {
            candle_core::bail!("MXFP4Layer requires CUDA or Metal device.");
        }

        let num_blocks = in_dim / MXFP4_BLOCK_SIZE;

        let blocks_4d = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, num_blocks, 16),
            &format!("{name}_blocks"),
            Default::default(),
            DType::U8,
        )?;

        let blocks = blocks_4d.reshape((num_local_experts, out_dim, num_blocks * 16))?;

        let scales = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, num_blocks),
            &format!("{name}_scales"),
            Default::default(),
            DType::U8,
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), &format!("{name}_bias"))?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }
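
    // Shape bookkeeping for the GPT-OSS layout above, with made-up sizes purely for
    // illustration: in_dim = 4096 gives num_blocks = 4096 / 32 = 128, so `{name}_blocks`
    // is [num_experts, out_dim, 128, 16] and the reshape yields
    // [num_experts, out_dim, 2048] = [num_experts, out_dim, in_dim / 2], the same packed
    // blocks layout that `packed_linear_b` loads directly.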

    /// FP4 E2M1 lookup table for dequantization
    const FP4_LUT: [f32; 16] = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ];

    /// Dequantize MXFP4 weights on the CPU.
    /// blocks: [num_experts, N, K/2] (or [N, K/2]) packed bytes
    /// scales: [num_experts, N, K/32] (or [N, K/32]) E8M0 scales
    /// Returns: [num_experts, N, K] (or [N, K]) weights, cast to BF16 on the source device
    fn dequantize_weights(&self) -> Result<Tensor> {
        let blocks_dims = self.blocks.dims();
        let scales_dims = self.scales.dims();

        let (num_experts, n, k_half) = if blocks_dims.len() == 3 {
            (blocks_dims[0], blocks_dims[1], blocks_dims[2])
        } else {
            (1, blocks_dims[0], blocks_dims[1])
        };
        let k = k_half * 2;

        let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
        let scales_cpu = self.scales.to_device(&Device::Cpu)?;

        let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
        let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;

        let num_scale_blocks = scales_dims[scales_dims.len() - 1];
        let mut weights = vec![0f32; num_experts * n * k];

        for expert in 0..num_experts {
            for n_idx in 0..n {
                for k_idx in 0..k {
                    let byte_idx = k_idx / 2;
                    let block_idx = k_idx / MXFP4_BLOCK_SIZE;

                    let blocks_offset = expert * n * k_half + n_idx * k_half + byte_idx;
                    let scales_offset =
                        expert * n * num_scale_blocks + n_idx * num_scale_blocks + block_idx;

                    let packed = blocks_data[blocks_offset];
                    let scale = scales_data[scales_offset];

                    let nibble = if k_idx % 2 == 0 {
                        packed & 0x0F
                    } else {
                        (packed >> 4) & 0x0F
                    };

                    let base = Self::FP4_LUT[nibble as usize];
                    let scale_factor = 2f32.powi(scale as i32 - 127);
                    let value = base * scale_factor;

                    let weight_idx = expert * n * k + n_idx * k + k_idx;
                    weights[weight_idx] = value;
                }
            }
        }

        let shape = if blocks_dims.len() == 3 {
            vec![num_experts, n, k]
        } else {
            vec![n, k]
        };

        Tensor::from_vec(weights, shape.as_slice(), &Device::Cpu)?
            .to_device(self.blocks.device())?
            .to_dtype(DType::BF16)
    }

    fn forward_dequantize(&self, x: &Tensor) -> Result<Tensor> {
        let orig_dims = x.dims().to_vec();

        let x_2d = if orig_dims.len() > 2 {
            let features = orig_dims[orig_dims.len() - 1];
            let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
            x.reshape((batch_size, features))?
        } else {
            x.clone()
        };

        let weights = self.dequantize_weights()?;
        let weight_t = weights.t()?;
        let mut result = x_2d.matmul(&weight_t)?;

        if let Some(bias) = &self.bias {
            result = result.broadcast_add(bias)?;
        }

        if orig_dims.len() > 2 {
            let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
            new_dims.push(result.dim(1)?);
            result = result.reshape(new_dims)?;
        }

        Ok(result)
    }

    fn gather_forward_dequantize(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let x_dims = x.dims();
        let indices_dims = indices.dims();

        let (num_tokens, topk, _k, x_has_topk) = if x_dims.len() == 2 {
            (x_dims[0], indices_dims[1], x_dims[1], false)
        } else {
            (x_dims[0], x_dims[1], x_dims[2], true)
        };

        let weights = self.dequantize_weights()?;
        let weight_dims = weights.dims();
        let n = weight_dims[1];

        let indices_cpu = indices.to_device(&Device::Cpu)?.to_dtype(DType::U32)?;
        let indices_data: Vec<u32> = indices_cpu.flatten_all()?.to_vec1()?;

        let mut outputs = Vec::with_capacity(num_tokens * topk);

        for token_idx in 0..num_tokens {
            for slot_idx in 0..topk {
                let expert_idx = indices_data[token_idx * topk + slot_idx] as usize;

                let input = if x_has_topk {
                    x.i((token_idx, slot_idx))?
                } else {
                    x.i(token_idx)?
                };

                let weight = weights.i(expert_idx)?;
                let input_2d = input.unsqueeze(0)?;
                let weight_t = weight.t()?;
                let mut output = input_2d.matmul(&weight_t)?.squeeze(0)?;

                if let Some(bias) = &self.bias {
                    let expert_bias = bias.i(expert_idx)?;
                    output = output.broadcast_add(&expert_bias)?;
                }

                outputs.push(output);
            }
        }

        let stacked = Tensor::stack(&outputs, 0)?;
        stacked.reshape((num_tokens, topk, n))
    }
}

impl QuantizedSerde for MXFP4Layer {
    fn name(&self) -> &'static str {
        "mxfp4-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        false
    }
}
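
// A minimal CPU sanity check for the dequantization path above. It is a sketch rather than
// exhaustive coverage: one hand-packed byte (0x21) and one E8M0 scale byte (128) are fed
// through `dequantize_weights`, which should yield 0.5 * 2^1 = 1.0 and 1.0 * 2^1 = 2.0 per
// the FP4_LUT and bias-127 exponent described earlier.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mxfp4_cpu_decode_single_block() -> Result<()> {
        // blocks: [N=1, K/2=1] -> one packed byte, low nibble 0x1 (0.5), high nibble 0x2 (1.0)
        let blocks = Tensor::from_vec(vec![0x21u8], (1, 1), &Device::Cpu)?;
        // scales: [N=1, K/32=1] -> E8M0 exponent 128, i.e. a scale factor of 2^(128 - 127) = 2
        let scales = Tensor::from_vec(vec![128u8], (1, 1), &Device::Cpu)?;

        let layer = MXFP4Layer {
            blocks,
            scales,
            bias: None,
        };

        // dequantize_weights returns BF16 on the source device; cast to f32 for comparison
        // (1.0 and 2.0 are exactly representable in BF16, so equality is safe here).
        let w = layer.dequantize_weights()?.to_dtype(DType::F32)?;
        let w: Vec<Vec<f32>> = w.to_vec2()?;
        assert_eq!(w, vec![vec![1.0, 2.0]]);
        Ok(())
    }
}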