mistralrs_quant/pertensor_fp8/mod.rs

use std::{
    borrow::Cow,
    sync::{atomic::AtomicUsize, Arc},
};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{serialize_tensor, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};

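/// Linear layer loaded from per-tensor FP8 (F8E4M3) checkpoints.
///
/// The FP8 weight is dequantized once at construction using `weight_scale_inv`,
/// so `forward` operates on the already-dequantized tensor.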
#[derive(Debug)]
pub struct PerTensorFP8Linear {
    weight: Tensor,
    #[allow(dead_code)]
    weight_scale_inv: Tensor,
    #[allow(dead_code)]
    activation_scale: Option<Tensor>,
    bias: Option<Tensor>,
    #[allow(dead_code)]
    dequant_dtype: DType,
}

impl QuantMethod for PerTensorFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::PerTensorFP8 {
                weight,
                weight_scale_inv,
                activation_scale,
                bias,
                dequant_dtype,
            } => {
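                // Dequantize the FP8 weight eagerly using the per-tensor scale.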
                let dequant_weight =
                    ops::fp8_pertensor_dequantize(&weight, &weight_scale_inv, dequant_dtype)?;
                Ok(Self {
                    weight: dequant_weight,
                    weight_scale_inv,
                    activation_scale,
                    bias,
                    dequant_dtype,
                })
            }
            _ => unreachable!(),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
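        // The weight was already dequantized in `new`.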
        Ok(self.weight.clone())
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
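        // Delegate the matmul to an unquantized linear over the dequantized weight.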
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            self.weight.clone(),
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("PerTensorFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (DType::F8E4M3, self.weight.device().clone())
    }

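    // Re-quantize the dequantized weight to the requested ISQ format
    // (HQQ, AFQ, GGUF, FP8, or unquantized when `dtype` is `None`).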
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.dequantize_w()?;
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

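// UQFF serialization: the already-dequantized weight is written out in the
// unquantized layout.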
impl QuantizedSerde for PerTensorFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "pertensor-fp8-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

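        // UQFF format version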
        buffer.extend(&UQFF_VERSION.to_le_bytes());

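        // Serde type tag: stored as unquantized, since the weight is already dequantized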
        buffer.push(QuantizedSerdeType::Unquant as u8);

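        // Whether a bias tensor follows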
        buffer.push(bias.is_some() as u8);

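        // Weight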
        serialize_tensor(&mut buffer, &self.weight)?;

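        // Bias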
        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
}

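/// Load a per-tensor FP8 linear layer from `vb`.
///
/// Falls back to a regular unquantized linear when no `weight_scale_inv` is
/// present, and to a dummy layer when the `weight` tensor itself is missing.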
pub fn pertensor_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    _config: &QuantizedConfig,
    bias: bool,
    _hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
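    // No per-tensor scale: this checkpoint is not FP8-quantized here, so load as a plain linear.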
    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
    }

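    // No weight tensor at all: return a dummy layer.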
    if !vb.contains_tensor("weight") {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

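    // FP8 weight stored as (out_dim, in_dim) in F8E4M3.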
    let weight = vb.get_with_hints_dtype(
        (out_dim, in_dim),
        "weight",
        Default::default(),
        DType::F8E4M3,
    )?;

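    // Per-tensor (scalar) F32 scale used for dequantization.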
    let weight_scale_inv =
        vb.get_with_hints_dtype((), "weight_scale_inv", Default::default(), DType::F32)?;

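    // Optional per-tensor activation scale, if the checkpoint provides one.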
    let activation_scale = if vb.contains_tensor("activation_scale") {
        Some(vb.get_with_hints_dtype((), "activation_scale", Default::default(), DType::F32)?)
    } else {
        None
    };

    let bias = if bias && vb.contains_tensor("bias") {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

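    // Dequantize to the bias dtype when a bias exists, otherwise default to BF16.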
    let dequant_dtype = bias.as_ref().map(|b| b.dtype()).unwrap_or(DType::BF16);

    Ok(Arc::new(PerTensorFP8Linear::new(
        QuantMethodConfig::PerTensorFP8 {
            weight,
            weight_scale_inv,
            activation_scale,
            bias,
            dequant_dtype,
        },
    )?))
}