use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;
pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};

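/// Linear layer whose weight is stored blockwise in FP8 (`F8E4M3`) alongside a per-block
/// inverse scale tensor (`weight_scale_inv`). `weight_block_size` gives the block shape over
/// the weight's output and input dimensions, and `dequant_dtype` is the dtype produced when
/// the weight is dequantized for fallback paths.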
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
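    // Reconstruct the full-precision weight by expanding each FP8 block with its per-block
    // inverse scale (see `ops::fp8_blockwise_dequantize`).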
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

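    // Forward: on CUDA with the blockwise FP8 GEMM kernels available, run the fused kernel
    // directly on the quantized weight; otherwise dequantize and fall back to an unquantized
    // linear layer.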
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                let orig_dims = x.dims().to_vec();
                let x_2d = if orig_dims.len() > 2 {
                    let features = orig_dims[orig_dims.len() - 1];
                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
                    x.reshape((batch_size, features))?
                } else {
                    x.clone()
                };

                let result = ops::fp8_blockwise_matmul(
                    &x_2d,
                    &self.weight,
                    &self.weight_scale_inv,
                    &self.weight_block_size,
                )?;

                let result = if orig_dims.len() > 2 {
                    let out_features = result.dim(1)?;
                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
                    new_dims.push(out_features);
                    result.reshape(new_dims)?
                } else {
                    result
                };

                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

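        // Fallback path: dequantize the weight and delegate to an unquantized linear layer.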
        let weight = self.dequantize_w()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

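    // MoE gather-forward: `indices` is (n_tokens, n_experts_per_tok) and selects which expert
    // sub-matrices of the stacked (n_experts, out_features, in_features) weight to apply per
    // token. Uses the indexed FP8 MoE GEMM on CUDA when available, otherwise a dense
    // dequantize + gather + matmul fallback.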
    fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "cuda")]
        {
            if matches!(x.device(), candle_core::Device::Cuda(_))
                && ffi::HAVE_BLOCKWISE_GEMM_KERNELS
            {
                let result = ops::fp8_indexed_moe_gemm(
                    x,
                    &self.weight,
                    &self.weight_scale_inv,
                    indices,
                    &self.weight_block_size,
                )?;
                if let Some(ref bias) = self.bias {
                    return result.broadcast_add(bias);
                }
                return Ok(result);
            }
        }

        let weight = self.dequantize_w()?;

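        // Dense fallback: gather the selected expert weight matrices, bring the activations to
        // shape (n_tokens * n_experts_per_tok, 1, in_features), and batch-multiply.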
        let (n_tokens, n_experts_per_tok) = indices.dims2()?;
        let (_n_experts, out_features, _in_features) = weight.dims3()?;

        let flat_indices = indices.flatten_all()?;

        let weight_selected = weight.index_select(&flat_indices, 0)?;

        let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
            x.squeeze(1)?
                .unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
                .contiguous()?
        } else if x.dims().len() == 3 {
            x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
        } else {
            x.unsqueeze(1)?
                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
                .contiguous()?
        };

        let weight_t = weight_selected.transpose(1, 2)?;
        let result = x_expanded.matmul(&weight_t)?;

        let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;

        if let Some(ref bias) = self.bias {
            result.broadcast_add(bias)
        } else {
            Ok(result)
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

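    // In-situ requantization: dequantize the blockwise FP8 weight once, then requantize into
    // the requested target (HQQ, AFQ, GGUF k-quants, plain FP8, or an unquantized layer when
    // `dtype` is `None`).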
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

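// Serialization of this layer is not supported (`isq_serde_supported` returns `false`).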
impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

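/// Build a [`BlockwiseFP8Linear`] from already-loaded MoE expert weights, with no bias. The
/// resulting layer is intended for use via `gather_forward`, which expects `weight` stacked as
/// `(n_experts, out_features, in_features)`.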
pub fn blockwise_fp8_moe(
    weight: Tensor,
    weight_scale_inv: Tensor,
    weight_block_size: Vec<usize>,
    dequant_dtype: DType,
) -> Result<Arc<dyn QuantMethod>> {
    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_scale_inv,
        bias: None,
        dequant_dtype,
        weight_block_size,
    }))
}

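/// Load a blockwise FP8 linear layer of logical shape `(out_dim, in_dim)` from `vb`.
///
/// - If `weight` exists but `weight_scale_inv` does not, the checkpoint is not blockwise FP8
///   and a plain (unquantized) linear layer is loaded instead.
/// - If neither tensor exists, a `DummyLayer` is returned.
/// - Otherwise the FP8 weight and its per-block inverse scales are loaded. The scale tensor
///   holds one `f32` per `weight_block_size[0] x weight_block_size[1]` block, i.e. it has shape
///   `(out_dim.div_ceil(bs[0]), in_dim.div_ceil(bs[1]))`; for example (illustrative sizes only),
///   `out_dim = 7168`, `in_dim = 2048` with block size `[128, 128]` gives a `(56, 16)` scale grid.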
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Fp8 { weight_block_size } = config else {
        candle_core::bail!("Unexpected quantization config.")
    };

    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
    }

    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}