mistralrs_quant/blockwise_fp8/mod.rs

use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};

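/// Linear layer whose weights are stored in blockwise FP8 (E4M3) form: an FP8
/// weight tensor plus one inverse scale per `weight_block_size` block, which is
/// dequantized to `dequant_dtype` on demand.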
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

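    // The forward pass dequantizes the full weight and dispatches to
    // `UnquantLinear`; there is no dedicated FP8 matmul kernel in this path.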
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let weight = self.dequantize_w()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

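    // In-situ requantization: dequantize the blockwise FP8 weight to full
    // precision, then requantize to the requested ISQ type (HQQ, AFQ, GGUF
    // k-quants, or FP8), or fall back to an unquantized layer when `dtype`
    // is `None`.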
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire();
                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

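// Serialization of this layer is not supported; `isq_serde_supported` reports
// false.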
impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

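/// Load a blockwise FP8 linear layer from `vb`. Expects a `weight` tensor in
/// FP8 (E4M3) and a per-block `weight_scale_inv` tensor in F32; if either is
/// missing, a `DummyLayer` is returned instead.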
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Fp8 { weight_block_size } = config else {
        candle_core::bail!("Unexpected quantization config.")
    };

    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
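    // One inverse scale per (weight_block_size[0] x weight_block_size[1]) block
    // of the (out_dim, in_dim) weight, hence the div_ceil dimensions below.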
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}