use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};

use crate::{
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
    ShardedVarBuilder,
};

#[cfg(feature = "cuda")]
pub(crate) mod ffi;
#[cfg(feature = "metal")]
pub(crate) mod metal_ops;
#[cfg(feature = "cuda")]
pub(crate) mod ops;

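/// Number of quantized weight values that share a single exponent scale (the MX block size).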
pub const MXFP4_BLOCK_SIZE: usize = 32;

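/// Bit width of each quantized weight element; two 4-bit values are packed per byte.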
pub(crate) const N_BITS: usize = 4;

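/// Linear (or packed per-expert) layer stored in MXFP4: 4-bit FP4 (E2M1) codes packed two
/// per byte in `blocks`, with one exponent scale per `MXFP4_BLOCK_SIZE` weights in `scales`.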
#[derive(Debug)]
pub struct MXFP4Layer {
    #[allow(dead_code)]
    blocks: Tensor,
    scales: Tensor,
    #[allow(dead_code)]
    bias: Option<Tensor>,
}

impl QuantMethod for MXFP4Layer {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::MXFP4 {
                blocks,
                scales,
                bias,
            } => Ok(Self {
                blocks,
                scales,
                bias,
            }),
        }
    }

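    /// Dequantize the full weight tensor. On Metal this dispatches to the AFQ dequantize
    /// kernel; otherwise it falls back to the CPU reference path in `dequantize_weights`.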
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        #[cfg(feature = "metal")]
        if self.blocks.device().is_metal() {
            use crate::afq::ops;
            use crate::{AfqBits, AfqGroupSize};
            return ops::afq_dequantize_op(
                &self.blocks,
                &self.scales,
                &self.scales.clone(),
                AfqGroupSize::Low,
                AfqBits::Mxfp4,
            );
        }
        self.dequantize_weights()
    }

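    /// Multiply `x` by the quantized weights, preferring the fused CUDA or Metal MXFP4 GEMM
    /// kernels and falling back to dequantize-then-matmul otherwise. Inputs with more than
    /// two dimensions are flattened to 2D and reshaped back afterwards.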
    #[allow(unused_variables)]
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
            let orig_dims = x.dims().to_vec();
            let x_2d = if orig_dims.len() > 2 {
                let features = orig_dims[orig_dims.len() - 1];
                let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                x.reshape((batch_size, features))?
            } else {
                x.clone()
            };

            let result = ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;

            if orig_dims.len() > 2 {
                let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                new_dims.push(result.dim(1)?);
                return result.reshape(new_dims);
            }
            return Ok(result);
        }

        #[cfg(feature = "metal")]
        {
            if x.device().is_metal() {
                let orig_dims = x.dims().to_vec();
                let x_2d = if orig_dims.len() > 2 {
                    let features = orig_dims[orig_dims.len() - 1];
                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                    x.reshape((batch_size, features))?
                } else {
                    x.clone()
                };

                let result =
                    metal_ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;

                if orig_dims.len() > 2 {
                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                    new_dims.push(result.dim(1)?);
                    return result.reshape(new_dims);
                }
                return Ok(result);
            }
        }

        self.forward_dequantize(x)
    }

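    /// Indexed MoE forward pass: multiply each token by the expert weights selected via
    /// `indices`, using the fused CUDA or Metal indexed GEMM kernels when available.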
    #[allow(unused_variables)]
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
            return ops::mxfp4_indexed_moe_gemm(
                x,
                &self.blocks,
                &self.scales,
                self.bias.as_ref(),
                indices,
            );
        }

        #[cfg(feature = "metal")]
        {
            if x.device().is_metal() {
                return metal_ops::mxfp4_indexed_moe_gemm(
                    x,
                    &self.blocks,
                    &self.scales,
                    self.bias.as_ref(),
                    indices,
                );
            }
        }

        self.gather_forward_dequantize(x, indices)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("MXFP4Layer does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::BF16, self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("MXFP4Layer does not support ISQ")
    }
}

impl MXFP4Layer {
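    /// Returns true if the fused MXFP4 kernels are usable on `device`
    /// (CUDA with the MXFP4 GEMM kernels compiled in, or any Metal device).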
    fn device_supported(_device: &Device) -> bool {
        #[cfg(feature = "cuda")]
        if matches!(_device, Device::Cuda(_)) {
            return ffi::HAVE_MXFP4_GEMM_KERNELS;
        }
        #[cfg(feature = "metal")]
        if _device.is_metal() {
            return true;
        }
        false
    }

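    /// Load a single MXFP4 linear layer from `vb`, expecting `blocks` of shape
    /// `(out_dim, in_dim / 2)` and `scales` of shape `(out_dim, in_dim / MXFP4_BLOCK_SIZE)`,
    /// both stored as `u8`.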
    pub fn linear_b(
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !Self::device_supported(vb.device()) {
            candle_core::bail!("MXFP4Layer requires CUDA or Metal device.");
        }

        let QuantizedConfig::MXFP4 {} = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let blocks = vb.get_with_hints_dtype(
            (out_dim, in_dim / 2),
            "blocks",
            Default::default(),
            DType::U8,
        )?;
        let scales = vb.get_with_hints_dtype(
            (out_dim, in_dim / MXFP4_BLOCK_SIZE),
            "scales",
            Default::default(),
            DType::U8,
        )?;

        let bias = if bias {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }

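    /// Load a stack of expert weights packed along a leading `num_local_experts` dimension,
    /// with the same per-expert layout as `linear_b`.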
    pub fn packed_linear_b(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        config: &QuantizedConfig,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !Self::device_supported(vb.device()) {
            candle_core::bail!("MXFP4Layer requires CUDA or Metal device.");
        }

        let QuantizedConfig::MXFP4 {} = config else {
            candle_core::bail!("Unexpected quantization config.")
        };

        let blocks = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim / 2),
            "blocks",
            Default::default(),
            DType::U8,
        )?;
        let scales = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, in_dim / MXFP4_BLOCK_SIZE),
            "scales",
            Default::default(),
            DType::U8,
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }

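    /// Load packed expert weights in the GPT-OSS-style checkpoint layout: blocks are stored
    /// as `(num_local_experts, out_dim, num_blocks, 16)` under `{name}_blocks` and flattened
    /// here to `(num_local_experts, out_dim, num_blocks * 16)`.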
    pub fn packed_gptoss_linear(
        num_local_experts: usize,
        in_dim: usize,
        out_dim: usize,
        bias: bool,
        name: &str,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        if !Self::device_supported(vb.device()) {
            candle_core::bail!("MXFP4Layer requires CUDA or Metal device.");
        }

        let num_blocks = in_dim / MXFP4_BLOCK_SIZE;

        let blocks_4d = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, num_blocks, 16),
            &format!("{name}_blocks"),
            Default::default(),
            DType::U8,
        )?;

        let blocks = blocks_4d.reshape((num_local_experts, out_dim, num_blocks * 16))?;

        let scales = vb.get_with_hints_dtype(
            (num_local_experts, out_dim, num_blocks),
            &format!("{name}_scales"),
            Default::default(),
            DType::U8,
        )?;

        let bias = if bias {
            Some(vb.get((num_local_experts, out_dim), &format!("{name}_bias"))?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            blocks,
            scales,
            bias,
        }))
    }

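    /// FP4 (E2M1) decode table: the low three bits of a code select the magnitude
    /// {0, 0.5, 1, 1.5, 2, 3, 4, 6} and the high bit selects the sign.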
    const FP4_LUT: [f32; 16] = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ];

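    /// CPU reference dequantization: unpack each 4-bit code from `blocks`, look it up in
    /// `FP4_LUT`, and multiply by the per-block scale `2^(e - 127)` taken from `scales`.
    /// Returns a BF16 tensor of shape `(n, k)` or `(num_experts, n, k)` on the blocks' device.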
    fn dequantize_weights(&self) -> Result<Tensor> {
        let blocks_dims = self.blocks.dims();
        let scales_dims = self.scales.dims();

        let (num_experts, n, k_half) = if blocks_dims.len() == 3 {
            (blocks_dims[0], blocks_dims[1], blocks_dims[2])
        } else {
            (1, blocks_dims[0], blocks_dims[1])
        };
        let k = k_half * 2;

        let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
        let scales_cpu = self.scales.to_device(&Device::Cpu)?;

        let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
        let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;

        let num_scale_blocks = scales_dims[scales_dims.len() - 1];
        let mut weights = vec![0f32; num_experts * n * k];

        for expert in 0..num_experts {
            for n_idx in 0..n {
                for k_idx in 0..k {
                    let byte_idx = k_idx / 2;
                    let block_idx = k_idx / MXFP4_BLOCK_SIZE;

                    let blocks_offset = expert * n * k_half + n_idx * k_half + byte_idx;
                    let scales_offset =
                        expert * n * num_scale_blocks + n_idx * num_scale_blocks + block_idx;

                    let packed = blocks_data[blocks_offset];
                    let scale = scales_data[scales_offset];

                    let nibble = if k_idx % 2 == 0 {
                        packed & 0x0F
                    } else {
                        (packed >> 4) & 0x0F
                    };

                    let base = Self::FP4_LUT[nibble as usize];
                    let scale_factor = 2f32.powi(scale as i32 - 127);
                    let value = base * scale_factor;

                    let weight_idx = expert * n * k + n_idx * k + k_idx;
                    weights[weight_idx] = value;
                }
            }
        }

        let shape = if blocks_dims.len() == 3 {
            vec![num_experts, n, k]
        } else {
            vec![n, k]
        };

        Tensor::from_vec(weights, shape.as_slice(), &Device::Cpu)?
            .to_device(self.blocks.device())?
            .to_dtype(DType::BF16)
    }

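    /// Fallback matmul path: dequantize the whole weight matrix, then compute
    /// `x @ W^T (+ bias)`, flattening higher-rank inputs to 2D and restoring the shape after.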
    fn forward_dequantize(&self, x: &Tensor) -> Result<Tensor> {
        let orig_dims = x.dims().to_vec();

        let x_2d = if orig_dims.len() > 2 {
            let features = orig_dims[orig_dims.len() - 1];
            let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
            x.reshape((batch_size, features))?
        } else {
            x.clone()
        };

        let weights = self.dequantize_weights()?;
        let weight_t = weights.t()?;
        let mut result = x_2d.matmul(&weight_t)?;

        if let Some(bias) = &self.bias {
            result = result.broadcast_add(bias)?;
        }

        if orig_dims.len() > 2 {
            let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
            new_dims.push(result.dim(1)?);
            result = result.reshape(new_dims)?;
        }

        Ok(result)
    }

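    /// Fallback indexed MoE path: dequantize all expert weights, then loop over every
    /// (token, top-k slot) pair using CPU-resident indices and matmul against the selected
    /// expert, returning a `(num_tokens, topk, n)` tensor.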
    fn gather_forward_dequantize(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let x_dims = x.dims();
        let indices_dims = indices.dims();

        let (num_tokens, topk, _k, x_has_topk) = if x_dims.len() == 2 {
            (x_dims[0], indices_dims[1], x_dims[1], false)
        } else {
            (x_dims[0], x_dims[1], x_dims[2], true)
        };

        let weights = self.dequantize_weights()?;
        let weight_dims = weights.dims();
        let n = weight_dims[1];

        let indices_cpu = indices.to_device(&Device::Cpu)?.to_dtype(DType::U32)?;
        let indices_data: Vec<u32> = indices_cpu.flatten_all()?.to_vec1()?;

        let mut outputs = Vec::with_capacity(num_tokens * topk);

        for token_idx in 0..num_tokens {
            for slot_idx in 0..topk {
                let expert_idx = indices_data[token_idx * topk + slot_idx] as usize;

                let input = if x_has_topk {
                    x.i((token_idx, slot_idx))?
                } else {
                    x.i(token_idx)?
                };

                let weight = weights.i(expert_idx)?;
                let input_2d = input.unsqueeze(0)?;
                let weight_t = weight.t()?;
                let mut output = input_2d.matmul(&weight_t)?.squeeze(0)?;

                if let Some(bias) = &self.bias {
                    let expert_bias = bias.i(expert_idx)?;
                    output = output.broadcast_add(&expert_bias)?;
                }

                outputs.push(output);
            }
        }

        let stacked = Tensor::stack(&outputs, 0)?;
        stacked.reshape((num_tokens, topk, n))
    }
}

impl QuantizedSerde for MXFP4Layer {
    fn name(&self) -> &'static str {
        "mxfp4-layer"
    }
    fn isq_serde_supported(&self) -> bool {
        false
    }
}