use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use quantize::QuantizationResult;

mod quantize;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, F8MatmulOutType, CUBLASLT_HANDLE},
    utils::{
        deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
        UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

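/// An FP8 (E4M3) quantized linear layer.
///
/// `lin` holds the already-quantized weight (and optional bias). `dequant_w_scale` and
/// `dequant_x_scale` are dequantization scales for the weight and the input (the latter is
/// recomputed in `forward` when the input is quantized on the fly), `quant_scale` is the
/// quantization scale, and `dtype` records the quantized dtype.
///
/// Construction sketch (illustrative only; `fp_lin` is a hypothetical full-precision
/// `candle_nn::Linear`):
///
/// ```ignore
/// let fp8 = FP8Linear::new(QuantMethodConfig::FP8 {
///     lin: fp_lin,
///     dtype: DType::F8E4M3,
/// })?;
/// ```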
#[derive(Debug)]
pub struct FP8Linear {
    lin: Linear,
    dequant_w_scale: Tensor,
    dequant_x_scale: Tensor,
    quant_scale: Tensor,
    dtype: DType,
}

impl QuantMethod for FP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::FP8 { lin, dtype } => {
                let QuantizationResult {
                    qw,
                    quantize_scale,
                    dequantize_scale,
                } = Self::quantize(lin.weight(), dtype)?;
                Ok(Self {
                    lin: Linear::new(qw, lin.bias().cloned()),
                    dequant_x_scale: dequantize_scale.clone(),
                    dequant_w_scale: dequantize_scale,
                    quant_scale: quantize_scale,
                    dtype,
                })
            }
        }
    }

    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
        Ok(self.dequantize(DType::F32)?.weight().clone())
    }

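    /// Forward pass for the FP8 linear layer.
    ///
    /// With a cuBLASLt handle available, the input is quantized to F8E4M3 (if it is not
    /// already) and multiplied against the quantized weight with `batch_matmul_f8`,
    /// producing a BF16 output; `x` must have at least 3 dimensions on this path.
    /// Without a handle, the weight is dequantized and a regular `Linear` forward is used.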
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        maybe_init_cublas_lt_wrapper(x.device().clone());

        match *CUBLASLT_HANDLE.lock().unwrap() {
            Some(handle) => {
                let n_dims = x.dims().len();
                if n_dims < 3 {
                    candle_core::bail!(
                        "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions"
                    );
                }
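                // The output keeps the leading dims of `x` with the last dim replaced by the
                // weight's output features; flatten `x` down to 3 dims for the batched matmul.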
                let mut tgt_shape = x.dims().to_vec();
                *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?;

                let mut x = x.flatten_to(D::Minus(3))?;

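                // If the input is not already FP8, quantize it on the fly and use its own
                // dequantization scale instead of the stored default.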
                let mut dequant_x_scale = self.dequant_x_scale.clone();
                if !matches!(x.dtype(), DType::F8E4M3) {
                    let QuantizationResult {
                        qw,
                        quantize_scale: _,
                        dequantize_scale,
                    } = Self::quantize(&x, DType::F8E4M3)?;
                    x = qw;
                    dequant_x_scale = dequantize_scale;
                }

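                // Pass beta = 1 when a bias is present so the bias is added to the matmul
                // output; otherwise leave it unset.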
                let beta = match self.lin.bias().is_some() {
                    true => Some(1.0),
                    false => None,
                };

                let a = self.lin.weight().unsqueeze(0)?;
                let b = x;

                handle
                    .batch_matmul_f8(
                        &a,
                        &b,
                        &self.dequant_w_scale,
                        &dequant_x_scale,
                        &self.quant_scale,
                        self.lin.bias(),
                        None,
                        beta,
                        None,
                        None,
                        F8MatmulOutType::BF16,
                    )?
                    .reshape(tgt_shape)
            }
            None => {
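                // No cuBLASLt handle available: dequantize the weight and run a standard
                // `Linear` forward in the input's dtype.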
                let dequant_x = x.clone();
                let lin = self.dequantize(x.dtype())?;
                lin.forward(&dequant_x)
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

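    /// Applies a weight delta by dequantizing, adding `delta`, and re-quantizing the
    /// result back to FP8 via `Self::new`.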
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let dequant = self.dequantize(delta.dtype())?;
        let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned());
        Ok(Arc::new(Self::new(QuantMethodConfig::FP8 {
            lin: new,
            dtype: self.dtype,
        })?))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (DType::F8E4M3, self.lin.weight().device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

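// UQFF serialization layout for `FP8Linear`, in write order (see `serialize_with_bias`):
//
// - UQFF version (u32, little endian)
// - ISQ type tag (`QuantizedSerdeType::Fp8`, u8)
// - "has bias" flag (u8)
// - quantized weight tensor (written by `serialize_tensor`)
// - dequant W scale, dequant X scale, quant scale (each f32, little endian)
// - dtype (written by `write_dtype`)
// - bias tensor (written by `serialize_tensor`), present only if the flag is set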
impl QuantizedSerde for FP8Linear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "fp8-linear"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.lin.bias().cloned())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Fp8 as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, self.lin.weight())?;

        buffer.extend(self.dequant_w_scale.to_scalar::<f32>()?.to_le_bytes());
        buffer.extend(self.dequant_x_scale.to_scalar::<f32>()?.to_le_bytes());
        buffer.extend(self.quant_scale.to_scalar::<f32>()?.to_le_bytes());

        write_dtype(self.dtype, &mut buffer);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

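    // Deserialization mirrors the byte layout written by `serialize_with_bias` above.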
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let w = deserialize_tensor(&mut buffer, device)?;

        let _acquired_load_guard = guard.acquire();
        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            lin: Linear::new(w, b),
            dequant_w_scale,
            dequant_x_scale,
            quant_scale,
            dtype,
        }))
    }
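    /// Like `deserialize`, but returns the bias separately and constructs the layer without it.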
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data.to_vec());

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Fp8 as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Fp8 as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w = deserialize_tensor(&mut buffer, device)?;

        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;

        let dtype = read_dtype(&mut buffer)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                lin: Linear::new(w, None),
                dequant_w_scale,
                dequant_x_scale,
                quant_scale,
                dtype,
            }),
            b,
        ))
    }
}