use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use quantize::QuantizationResult;

mod quantize;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    utils::{
        deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
        UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

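/// Linear layer whose weight is stored in FP8 (E4M3) with per-tensor scales:
/// `dequant_w_scale` and `dequant_x_scale` dequantize the weight and activation,
/// `quant_scale` re-quantizes the matmul result, and `dtype` is the dtype the
/// layer was configured with.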
#[derive(Debug)]
pub struct FP8Linear {
    lin: Linear,
    dequant_w_scale: Tensor,
    dequant_x_scale: Tensor,
    quant_scale: Tensor,
    dtype: DType,
}

impl QuantMethod for FP8Linear {
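    // Only `QuantMethodConfig::FP8` is expected here; the other config variants are
    // never passed to this constructor, so those arms are unreachable.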
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::FP8 { lin, dtype } => {
                let QuantizationResult {
                    qw,
                    quantize_scale,
                    dequantize_scale,
                } = Self::quantize(lin.weight(), dtype)?;
                Ok(Self {
                    lin: Linear::new(qw, lin.bias().cloned()),
                    dequant_w_scale: dequantize_scale.clone(),
                    dequant_x_scale: dequantize_scale,
                    quant_scale: quantize_scale,
                    dtype,
                })
            }
        }
    }
    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        Ok(self.dequantize(DType::F32)?.weight().clone())
    }

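    // Fast path: use the cuBLASLt FP8 batched matmul when a handle exists for this
    // device; otherwise dequantize the weight and fall back to a plain linear forward.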
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        maybe_init_cublas_lt_wrapper(x.device().clone());

        match CUBLASLT_CONTROLLER.get_for_device(x.device()) {
            Some(handle) => {
                let n_dims = x.dims().len();
                if n_dims < 3 {
                    candle_core::bail!(
                        "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions"
                    );
                }
                let mut tgt_shape = x.dims().to_vec();
                *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?;

                let mut x = x.flatten_to(D::Minus(3))?;

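                // If the activation is not already FP8, quantize it to F8E4M3 on the
                // fly and use its dequantization scale for the matmul.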
                let mut dequant_x_scale = self.dequant_x_scale.clone();
                if !matches!(x.dtype(), DType::F8E4M3) {
                    let QuantizationResult {
                        qw,
                        quantize_scale: _,
                        dequantize_scale,
                    } = Self::quantize(&x, DType::F8E4M3)?;
                    x = qw;
                    dequant_x_scale = dequantize_scale;
                }

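                // Only add the bias contribution (beta = 1.0) when the layer
                // actually has a bias.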
                let beta = match self.lin.bias().is_some() {
                    true => Some(1.0),
                    false => None,
                };

                let a = self.lin.weight().unsqueeze(0)?;
                let b = x;

                handle
                    .batch_matmul_f8(
                        &a,
                        &b,
                        &self.dequant_w_scale,
                        &dequant_x_scale,
                        &self.quant_scale,
                        self.lin.bias(),
                        None,
                        beta,
                        None,
                        None,
                    )?
                    .reshape(tgt_shape)
            }
            None => {
                let dequant_x = x.clone();
                let lin = self.dequantize(x.dtype())?;
                lin.forward(&dequant_x)
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

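    // Apply a delta to the weights by dequantizing, adding the delta, and
    // re-quantizing the result back to FP8.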
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize(delta.dtype())?;
        let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned());
        Ok(Arc::new(Self::new(QuantMethodConfig::FP8 {
            lin: new,
            dtype: self.dtype,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.lin.weight().device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for FP8Linear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "fp8-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.lin.bias().cloned())
    }
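    // UQFF layout: version, serde type tag, bias flag, FP8 weight tensor, the three
    // f32 scales (dequant_w, dequant_x, quant), the dtype, then the optional bias.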
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Fp8 as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, self.lin.weight())?;

        buffer.extend(self.dequant_w_scale.to_scalar::<f32>()?.to_le_bytes());
        buffer.extend(self.dequant_x_scale.to_scalar::<f32>()?.to_le_bytes());
        buffer.extend(self.quant_scale.to_scalar::<f32>()?.to_le_bytes());

        write_dtype(self.dtype, &mut buffer);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

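    // Reads back the layout written by `serialize_with_bias` and reconstructs the
    // layer, attaching the bias (if present) to the inner `Linear`.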
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let w = deserialize_tensor(&mut buffer, device)?;

        let _acquired_load_guard = guard.acquire(device);
        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            lin: Linear::new(w, b),
            dequant_w_scale,
            dequant_x_scale,
            quant_scale,
            dtype,
        }))
    }
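    // Same layout as `deserialize`, but the bias is returned separately rather than
    // being attached to the inner `Linear`.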
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                lin: Linear::new(w, None),
                dequant_w_scale,
                dequant_x_scale,
                quant_scale,
                dtype,
            }),
            b,
        ))
    }
}