use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;
pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};

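/// Linear layer whose weight is stored in FP8 (E4M3) with one scale per
/// `weight_block_size` block; it is dequantized (or multiplied blockwise on CUDA)
/// at runtime.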
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
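    /// Dequantize the FP8 weight back to `dequant_dtype` using the stored per-block scales.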
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
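        // Fast path: on CUDA with the blockwise GEMM kernels available, multiply directly
        // against the FP8 weight and per-block scales, flattening any leading batch dims
        // to 2D first: (b, s, in) -> (b * s, in) -> matmul -> (b * s, out) -> (b, s, out).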
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                let orig_dims = x.dims().to_vec();
                let x_2d = if orig_dims.len() > 2 {
                    let features = orig_dims[orig_dims.len() - 1];
                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                    x.reshape((batch_size, features))?
                } else {
                    x.clone()
                };

                let result = ops::fp8_blockwise_matmul(
                    &x_2d,
                    &self.weight,
                    &self.weight_scale_inv,
                    &self.weight_block_size,
                )?;

                let result = if orig_dims.len() > 2 {
                    let out_features = result.dim(1)?;
                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                    new_dims.push(out_features);
                    result.reshape(new_dims)?
                } else {
                    result
                };

                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

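        // Fallback: dequantize the full weight and run an ordinary unquantized linear.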
        let weight = self.dequantize_w()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
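        // Fast path: the fused indexed MoE GEMM kernel takes the expert indices directly,
        // so no explicit gather of expert weights is needed here.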
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                let result = ops::fp8_indexed_moe_gemm(
                    x,
                    &self.weight,
                    &self.weight_scale_inv,
                    indices,
                    &self.weight_block_size,
                )?;
                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

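        // Fallback: dequantize every expert's weight, then gather and matmul with standard
        // tensor ops.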
        let weight = self.dequantize_w()?;

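        // Shapes: `indices` is (n_tokens, n_experts_per_tok); the dequantized `weight`
        // is (n_experts, out_features, in_features).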
        let (n_tokens, n_experts_per_tok) = indices.dims2()?;
        let (_n_experts, out_features, _in_features) = weight.dims3()?;

        let flat_indices = indices.flatten_all()?;

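        // Gather the selected expert's weight matrix for every (token, expert) pair:
        // (n_tokens * n_experts_per_tok, out_features, in_features).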
        let weight_selected = weight.index_select(&flat_indices, 0)?;

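        // Expand the activations to one row per (token, expert) pair:
        // (n_tokens * n_experts_per_tok, 1, in_features).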
        let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
            x.squeeze(1)?
                .unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
                .contiguous()?
        } else if x.dims().len() == 3 {
            x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
        } else {
            x.unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
                .contiguous()?
        };

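        // Batched matmul: (N, 1, in) x (N, in, out) -> (N, 1, out), with
        // N = n_tokens * n_experts_per_tok, then reshape back to
        // (n_tokens, n_experts_per_tok, out_features).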
        let weight_t = weight_selected.transpose(1, 2)?;
        let result = x_expanded.matmul(&weight_t)?;

        let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;

        if let Some(ref bias) = self.bias {
            result.broadcast_add(bias)
        } else {
            Ok(result)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

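    // In-situ requantization: dequantize the blockwise FP8 weight once, then quantize it
    // into the requested ISQ target (HQQ, AFQ, GGUF k-quants, per-tensor FP8), or fall
    // back to an unquantized linear when no target is given.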
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
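            // The GGUF k-quant targets are the only ones here that can use an importance
            // matrix (imatrix) to weight the quantization error.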
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

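// Serialization through the ISQ serde path is not supported for blockwise FP8 layers.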
impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

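/// Construct a blockwise FP8 layer for MoE expert weights (no bias) from already-loaded tensors.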
pub fn blockwise_fp8_moe(
    weight: Tensor,
    weight_scale_inv: Tensor,
    weight_block_size: Vec<usize>,
    dequant_dtype: DType,
) -> Result<Arc<dyn QuantMethod>> {
    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_scale_inv,
        bias: None,
        dequant_dtype,
        weight_block_size,
    }))
}

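/// Load a blockwise FP8 linear layer from a `ShardedVarBuilder`, falling back to an
/// unquantized linear or a dummy layer depending on which tensors the checkpoint contains.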
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Fp8 { weight_block_size } = config else {
        candle_core::bail!("Unexpected quantization config.")
    };

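    // `weight` present without `weight_scale_inv`: the checkpoint stores this layer
    // unquantized, so load it as a plain linear layer.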
    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
    }

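    // The quantized weight is missing entirely: return a dummy placeholder layer.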
    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let Some(weight_block_size) = weight_block_size else {
        candle_core::bail!("Blockwise FP8 requires weight_block_size to be set. Use per-tensor FP8 for models without block sizes.")
    };
    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
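    // `weight` is stored as (out_dim, in_dim) in FP8 (E4M3); `weight_scale_inv` holds one
    // F32 scale per (weight_block_size[0], weight_block_size[1]) tile, hence the
    // ceil-divided shape.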
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}