mistralrs_quant/mxfp4/mod.rs
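//! MXFP4 quantized linear layers.
//!
//! MXFP4 is the OCP microscaling 4-bit floating-point format: weights are
//! stored as FP4 (E2M1) values in groups of 32 elements, each group sharing a
//! single E8M0 scale. The packed `blocks` and per-group `scales` tensors are
//! fed to the AFQ kernels (`crate::afq::ops`) for matmul and dequantization;
//! this path is currently Metal-only.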

use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{DType, Device, Result, Tensor};

use crate::{
    AfqBits, AfqGroupSize, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, ShardedVarBuilder,
};

use crate::afq::ops;

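// MXFP4 quantizes weights in groups of 32 elements that share one scale; the
// const assertion pins the `AfqGroupSize::Low` discriminant to 32 at compile
// time.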
const GROUP_SIZE: AfqGroupSize = AfqGroupSize::Low;
const _: () = assert!(GROUP_SIZE as usize == 32);

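// Bit-mode selector telling the AFQ kernels to decode the packed data as
// MXFP4; the assertion pins the expected enum discriminant.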
const BITS: AfqBits = AfqBits::Mxfp4;
const _: () = assert!(BITS as usize == 40);

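// Width of a single MXFP4 element in bits.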
pub(crate) const N_BITS: usize = 4;

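/// Linear layer whose weights are stored in MXFP4 format: `blocks` holds the
/// packed 4-bit values, `scales` holds one E8M0 scale per 32-element group,
/// and `bias` is an optional bias tensor.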
#[derive(Debug)]
pub struct MXFP4Layer {
    blocks: Tensor,
    scales: Tensor,
    bias: Option<Tensor>,
}

impl QuantMethod for MXFP4Layer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
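        // Only the MXFP4 config variant can construct this layer; any other
        // config reaching this constructor is a programmer error.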
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::MXFP4 {
                blocks,
                scales,
                bias,
            } => Ok(Self {
                blocks,
                scales,
                bias,
            }),
        }
    }

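    // Reconstructs the dense weight tensor through the AFQ dequantize kernel.
    // `scales` is passed for both the scales and biases slots since this
    // layer carries no AFQ-style biases tensor.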
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::afq_dequantize_op(
            &self.blocks,
            &self.scales,
            &self.scales,
            GROUP_SIZE,
            BITS,
        )
    }

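    // Quantized matmul against the packed weights via the AFQ matmul kernel,
    // followed by an optional broadcast bias add. As in `dequantize_w`,
    // `scales` fills both the scales and biases slots.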
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = ops::afq_mm_op(
            x,
            &self.blocks,
            &self.scales,
            &self.scales,
            None,
            None,
            GROUP_SIZE,
            BITS,
            true,
        )?;
        if let Some(bias) = &self.bias {
            x = x.broadcast_add(bias)?;
        }
        Ok(x)
    }

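    // Indexed (MoE) variant of `forward`: `indices` selects, per token, which
    // stacked expert weight matrix to multiply against.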
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let mut x = ops::afq_mm_op(
            x,
            &self.blocks,
            &self.scales,
            &self.scales,
            None,
            Some(indices),
            GROUP_SIZE,
            BITS,
            true,
        )?;
        if let Some(bias) = &self.bias {
            x = x.broadcast_add(bias)?;
        }
        Ok(x)
    }

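    // No activation dtype cast is required before the quantized matmul.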
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("MXFP4Layer does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        // Requantizing MXFP4 weights in place is not implemented; return an
        // error rather than panicking with `todo!()`.
        candle_core::bail!("MXFP4Layer does not support apply_isq")
    }
}

impl MXFP4Layer {
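    /// Load a single MXFP4 linear layer from `vb`, reading packed 4-bit
    /// `blocks` and E8M0 `scales` (plus an optional `bias`). Metal-only.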
    pub fn linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        // The AFQ kernels backing this layer are Metal-only.
        if !vb.device().is_metal() {
            candle_core::bail!("MXFP4Layer only works on Metal.");
        }

        let QuantizedConfig::MXFP4 {} = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let group_size = GROUP_SIZE as usize;

        // Packed 4-bit weight values along the input dimension.
        let blocks = vb.get_with_hints_dtype(
            (out_dim, in_dim * N_BITS / 32),
            "blocks",
            Default::default(),
            DType::F4,
        )?;
        // One E8M0 scale per group of `group_size` weights.
        let scales = vb.get_with_hints_dtype(
            (out_dim, in_dim / group_size),
            "scales",
            Default::default(),
            DType::F8E8M0,
        )?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }

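    /// Load `num_local_experts` stacked MXFP4 linear layers (e.g. MoE expert
    /// weights) as a single layer suitable for `gather_forward`.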
    pub fn packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !vb.device().is_metal() {
            candle_core::bail!("MXFP4Layer only works on Metal.");
        }

        let QuantizedConfig::MXFP4 {} = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let group_size = GROUP_SIZE as usize;

        // Same layout as `linear_b`, with a leading experts dimension.
        let blocks = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim * N_BITS / 32),
            "blocks",
            Default::default(),
            DType::F4,
        )?;
        let scales = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim / group_size),
            "scales",
            Default::default(),
            DType::F8E8M0,
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }
}

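// This layer does not support ISQ serde (serialization of the quantized
// weights); only a name for identification is provided.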
impl QuantizedSerde for MXFP4Layer {
    fn name(&self) -> &'static str {
        "mxfp4-layer"
    }

    fn isq_serde_supported(&self) -> bool {
        false
    }
}