1use std::sync::{atomic::AtomicUsize, Arc};
2
3use candle_core::{DType, Device, Result, Tensor};
4
5use crate::{
6 AfqBits, AfqGroupSize, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
7 QuantizedConfig, QuantizedSerde, ShardedVarBuilder,
8};
9
10use crate::afq::ops;
11
// Quantization group size used by the AFQ kernels for MXFP4 weights.
const GROUP_SIZE: AfqGroupSize = AfqGroupSize::Low;
// Compile-time guard: the kernels below are written for 32-element groups,
// so fail the build if `AfqGroupSize::Low` ever changes.
const _: () = assert!(GROUP_SIZE as usize == 32);

// Bit-layout selector passed to the AFQ ops; selects the MXFP4 code path.
const BITS: AfqBits = AfqBits::Mxfp4;
// Compile-time guard on the enum discriminant (presumably a sentinel value,
// not a literal bit width — NOTE(review): confirm 40 against `AfqBits`).
const _: () = assert!(BITS as usize == 40);

// Actual bits per quantized element in MXFP4 (used for shape arithmetic
// when loading packed weight tensors).
pub(crate) const N_BITS: usize = 4;
19
/// A linear layer whose weights are stored in MXFP4 quantized form.
///
/// Matmul and dequantization are delegated to the AFQ Metal kernels in
/// [`crate::afq::ops`]; see `linear_b`/`packed_linear_b` for how the
/// tensors are loaded.
#[derive(Debug)]
pub struct MXFP4Layer {
    // Packed 4-bit weight values (loaded with `DType::F4`).
    blocks: Tensor,
    // Per-group scale factors (loaded with `DType::F8E8M0`), one per
    // `GROUP_SIZE` elements along the input dimension.
    scales: Tensor,
    // Optional additive bias, broadcast onto the matmul output.
    bias: Option<Tensor>,
}
26
27impl QuantMethod for MXFP4Layer {
28 fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
29 where
30 Self: Sized,
31 {
32 match method {
33 QuantMethodConfig::Gguf { .. }
34 | QuantMethodConfig::GptqAwq { .. }
35 | QuantMethodConfig::Hqq { .. }
36 | QuantMethodConfig::Dummy
37 | QuantMethodConfig::FP8 { .. }
38 | QuantMethodConfig::Bnb { .. }
39 | QuantMethodConfig::BlockwiseFP8 { .. }
40 | QuantMethodConfig::Unquantized(_)
41 | QuantMethodConfig::Afq { .. } => unreachable!(),
42 QuantMethodConfig::MXFP4 {
43 blocks,
44 scales,
45 bias,
46 } => Ok(Self {
47 blocks,
48 scales,
49 bias,
50 }),
51 }
52 }
53
54 fn dequantize_w(&self) -> Result<candle_core::Tensor> {
55 ops::afq_dequantize_op(
56 &self.blocks,
57 &self.scales,
58 &self.scales.clone(),
59 GROUP_SIZE,
60 BITS,
61 )
62 }
63
64 fn forward(&self, x: &Tensor) -> Result<Tensor> {
65 let mut x = ops::afq_mm_op(
66 x,
67 &self.blocks,
68 &self.scales,
69 &self.scales.clone(),
70 None,
71 None,
72 GROUP_SIZE,
73 BITS,
74 true,
75 )?;
76 if let Some(bias) = &self.bias {
77 x = x.broadcast_add(bias)?;
78 }
79 Ok(x)
80 }
81
82 fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
83 let mut x = ops::afq_mm_op(
84 x,
85 &self.blocks,
86 &self.scales,
87 &self.scales.clone(),
88 None,
89 Some(indices),
90 GROUP_SIZE,
91 BITS,
92 true,
93 )?;
94 if let Some(bias) = &self.bias {
95 x = x.broadcast_add(bias)?;
96 }
97 Ok(x)
98 }
99
100 fn quantized_act_type(&self) -> Option<DType> {
101 None
102 }
103
104 fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
105 candle_core::bail!("MXFP4Layer does not support add_delta_w")
106 }
107
108 fn dtype_and_device(&self) -> (DType, candle_core::Device) {
109 (self.scales.dtype(), self.scales.device().clone())
110 }
111
112 fn apply_isq(
113 self: Arc<Self>,
114 _dtype: Option<IsqType>,
115 _device: Device,
116 _n_quantized: &AtomicUsize,
117 _imatrix_weight: Option<Vec<f32>>,
118 _guard: QuantizeOntoGuard,
119 ) -> Result<Arc<dyn QuantMethod>> {
120 todo!()
121 }
122}
123
124impl MXFP4Layer {
125 pub fn linear_b(
126 in_dim: usize,
127 out_dim: usize,
128 config: &QuantizedConfig,
129 bias: bool,
130 vb: ShardedVarBuilder,
131 ) -> Result<Arc<dyn QuantMethod>> {
132 if !vb.device().is_metal() {
133 candle_core::bail!("MXFP4Layer only works on Metal.");
134 }
135
136 let QuantizedConfig::MXFP4 {} = config else {
137 candle_core::bail!("Unexpected quantization config.")
138 };
139
140 let group_size = GROUP_SIZE as usize;
141
142 let blocks = vb.get_with_hints_dtype(
143 (out_dim, in_dim * N_BITS / 32),
144 "blocks",
145 Default::default(),
146 DType::F4,
147 )?;
148 let scales = vb.get_with_hints_dtype(
149 (out_dim, in_dim / group_size),
150 "scales",
151 Default::default(),
152 DType::F8E8M0,
153 )?;
154
155 let bias = if bias {
156 Some(vb.get((out_dim,), "bias")?)
157 } else {
158 None
159 };
160
161 Ok(Arc::new(Self {
162 blocks,
163 scales,
164 bias,
165 }))
166 }
167
168 pub fn packed_linear_b(
169 num_local_experts: usize,
170 in_dim: usize,
171 out_dim: usize,
172 config: &QuantizedConfig,
173 bias: bool,
174 vb: ShardedVarBuilder,
175 ) -> Result<Arc<dyn QuantMethod>> {
176 if !vb.device().is_metal() {
177 candle_core::bail!("MXFP4Layer only works on Metal.");
178 }
179
180 let QuantizedConfig::MXFP4 {} = config else {
181 candle_core::bail!("Unexpected quantization config.")
182 };
183
184 let group_size = GROUP_SIZE as usize;
185
186 let blocks = vb.get_with_hints_dtype(
187 (num_local_experts, out_dim, in_dim * N_BITS / 32),
188 "blocks",
189 Default::default(),
190 DType::F4,
191 )?;
192 let scales = vb.get_with_hints_dtype(
193 (num_local_experts, out_dim, in_dim / group_size),
194 "scales",
195 Default::default(),
196 DType::F8E8M0,
197 )?;
198
199 let bias = if bias {
200 Some(vb.get((num_local_experts, out_dim), "bias")?)
201 } else {
202 None
203 };
204
205 Ok(Arc::new(Self {
206 blocks,
207 scales,
208 bias,
209 }))
210 }
211}
212
213impl QuantizedSerde for MXFP4Layer {
214 fn name(&self) -> &'static str {
215 "mxfp4-layer"
216 }
217 fn isq_serde_supported(&self) -> bool {
218 false
219 }
220}