mistralrs_quant/blockwise_fp8/mod.rs

use std::{
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;

#[cfg(feature = "cuda")]
mod ffi;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits, HqqConfig, HqqLayer, IsqType, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, Shard,
    ShardedVarBuilder, UnquantLinear,
};

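/// Linear layer whose weight is stored in FP8 (E4M3), with one `f32` inverse
/// scale per `weight_block_size` block; the weight is dequantized to
/// `dequant_dtype` on demand.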
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
    weight: Tensor,
    weight_scale_inv: Tensor,
    bias: Option<Tensor>,
    dequant_dtype: DType,
    weight_block_size: Vec<usize>,
}

impl QuantMethod for BlockwiseFP8Linear {
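    /// Construct from a `QuantMethodConfig`; only the `BlockwiseFP8` variant is
    /// accepted, every other variant is unreachable for this layer type.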
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::FP8 { .. } => unreachable!(),
            QuantMethodConfig::BlockwiseFP8 {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            } => Ok(Self {
                weight,
                weight_scale_inv,
                bias,
                dequant_dtype,
                weight_block_size,
            }),
        }
    }
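
    /// Dequantize the FP8 weight back to `dequant_dtype` using the per-block
    /// inverse scales (`weight_scale_inv`).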
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )
    }

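    /// Dequantizes the full weight and runs a standard unquantized matmul
    /// (plus the optional bias).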
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let weight = self.dequantize_w()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

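    /// In-situ re-quantization: dequantize the blockwise FP8 weight, then
    /// re-quantize into the requested target (HQQ, GGUF quants, plain FP8), or
    /// return an unquantized linear when `dtype` is `None`.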
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = ops::fp8_blockwise_dequantize(
            &self.weight,
            &self.weight_scale_inv,
            self.weight_block_size.to_vec(),
            self.dequant_dtype,
        )?;
        match dtype {
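            // HQQ (4- or 8-bit) re-quantization; HQQ cannot consume an
            // importance matrix, so this bails if one was provided.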
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
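            // GGUF quant targets (k-quants and the legacy Q4/Q5/Q8 formats),
            // optionally weighted by an importance matrix when one is provided.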
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
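            // Plain FP8 (F8E4M3) re-quantization via FP8Linear; imatrix is not
            // supported here either.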
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
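            // No ISQ target requested: keep the dequantized weight as a plain
            // unquantized linear layer on the target device.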
            None => {
                let _acquired_quantize_guard = guard.acquire();
                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

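    /// No ISQ type imposes a maximum CPU thread count for this layer (every
    /// arm returns `None`).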
    fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option<NonZeroUsize> {
        match dtype {
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1
            | IsqType::HQQ4
            | IsqType::HQQ8 => None,
        }
    }
}

impl QuantizedSerde for BlockwiseFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str {
        "blockwise-fp8-linear"
    }
}

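/// Load a blockwise FP8 linear layer from `vb`, or a `DummyLayer` if the
/// expected `weight`/`weight_scale_inv` tensors are not present.
///
/// `config.weight_block_size` must have exactly two entries (the block sizes
/// along the output and input dimensions); the scale tensor has shape
/// `(ceil(out_dim / block[0]), ceil(in_dim / block[1]))`.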
pub fn blockwise_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    bias: bool,
    hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let weight_block_size = config
        .weight_block_size
        .as_ref()
        .expect("Blockwise FP8 requires weight_block_size in config");
    if weight_block_size.len() != 2 {
        candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
    }
    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
    let weight_scale_inv = vb.get_with_hints_dtype(
        (
            out_dim.div_ceil(weight_block_size[0]),
            in_dim.div_ceil(weight_block_size[1]),
        ),
        "weight_scale_inv",
        hints,
        DType::F32,
    )?;
    let bias = if bias {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    Ok(Arc::new(BlockwiseFP8Linear {
        weight,
        weight_block_size: weight_block_size.clone(),
        weight_scale_inv,
        bias,
        dequant_dtype: vb.dtype(),
    }))
}
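
// A minimal usage sketch (not from this file). `cfg`, `vb`, `x`, the
// dimensions, and the 128x128 block size are all hypothetical assumptions
// about the caller's checkpoint and config, not requirements of this code.
//
//     let linear = blockwise_fp8_linear_b(
//         /* in_dim */ 4096,
//         /* out_dim */ 11008,
//         &cfg,              // QuantizedConfig with weight_block_size = Some(vec![128, 128])
//         /* bias */ false,
//         Shard::default(),  // no sharding (assumes Shard implements Default)
//         vb,                // ShardedVarBuilder positioned at the layer prefix
//     )?;
//     let y = linear.forward(&x)?;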