use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
4
/// `CustomOp2` that dequantizes a rank-2 FP8 (E4M3) weight matrix using one
/// f32 scale per tile of the weight.
struct Fp8BlockwiseDequantize {
    // `[block_rows, block_cols]`: size of the weight tile covered by a single
    // scale entry. Must have length 2 (validated in `cpu_fwd`/`cuda_fwd`).
    weight_block_size: Vec<usize>,
    // Output dtype of the dequantized tensor; F32, F16, and BF16 are supported.
    out_ty: DType,
}
9
10impl Fp8BlockwiseDequantize {
11 fn dispatch_dequant_blockwise<T: WithDType>(
12 &self,
13 weight: &[F8E4M3],
14 scale: &[f32],
15 weight_l: &candle_core::Layout,
16 scale_l: &candle_core::Layout,
17 ) -> candle_core::Result<Vec<T>> {
18 let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
19 let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);
20
21 let res = vec![T::zero(); weight.len()];
22
23 (0..grid_y).into_par_iter().for_each(|y| {
24 (0..grid_x).into_par_iter().for_each(|x| {
25 let res_ptr = res.as_ptr() as *mut T;
26
27 let scale = scale[y * scale_l.stride()[0] + x];
28
29 let start_y = y * self.weight_block_size[0];
30 let end_y = start_y + self.weight_block_size[0];
31
32 let start_x = x * self.weight_block_size[1];
33 let end_x = start_x + self.weight_block_size[1];
34
35 for weight_y in start_y..end_y {
36 if weight_y >= weight_l.dims()[0] {
37 break;
38 }
39
40 let row_offset = weight_y * weight_l.stride()[0];
41 for weight_x in start_x..end_x {
42 if weight_x >= weight_l.dims()[1] {
43 break;
44 }
45
46 let weight_pos = row_offset + weight_x;
47
48 unsafe {
50 *res_ptr.wrapping_add(weight_pos) =
51 T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
52 }
53 }
54 }
55 });
56 });
57
58 Ok(res)
59 }
60}
61
62impl CustomOp2 for Fp8BlockwiseDequantize {
63 fn name(&self) -> &'static str {
64 "fp8-blockwise-dequantize"
65 }
66
67 fn cpu_fwd(
68 &self,
69 scale_s: &candle_core::CpuStorage,
70 scale_l: &candle_core::Layout,
71 weight_s: &candle_core::CpuStorage,
72 weight_l: &candle_core::Layout,
73 ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
74 let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
75 candle_core::bail!("Expected F8E4M3 weight!");
76 };
77 let candle_core::CpuStorage::F32(scale) = scale_s else {
78 candle_core::bail!("Expected F8E4M3 weight!");
79 };
80 if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
81 candle_core::bail!("Expected weight to have start offset 0, continuous");
82 }
83 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
84 candle_core::bail!("Expected scales to have start offset 0, continuous");
85 }
86 if weight_l.dims().len() != 2 {
87 candle_core::bail!("Expected weight to be rank 2");
88 }
89 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
90 candle_core::bail!("Expected scale to be rank 2");
91 }
92
93 match self.out_ty {
94 DType::F32 => Ok((
95 CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
96 weight_l.shape().clone(),
97 )),
98 DType::BF16 => Ok((
99 CpuStorage::BF16(
100 self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
101 ),
102 weight_l.shape().clone(),
103 )),
104 DType::F16 => Ok((
105 CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
106 weight_l.shape().clone(),
107 )),
108 other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
109 }
110 }
111
112 #[cfg(feature = "cuda")]
113 fn cuda_fwd(
114 &self,
115 scale_s: &candle_core::CudaStorage,
116 scale_l: &candle_core::Layout,
117 weight_s: &candle_core::CudaStorage,
118 weight_l: &candle_core::Layout,
119 ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
120 use candle_core::{
121 backend::BackendStorage,
122 cuda::{cudarc::driver::DevicePtr, WrapErr},
123 CudaStorage,
124 };
125 use half::{bf16, f16};
126
127 use crate::blockwise_fp8::ffi;
128
129 if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
130 candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
131 }
132
133 if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
134 candle_core::bail!("Expected weight to have start offset 0, continuous");
135 }
136 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
137 candle_core::bail!("Expected scales to have start offset 0, continuous");
138 }
139 if weight_l.dims().len() != 2 {
140 candle_core::bail!("Expected weight to be rank 2");
141 }
142 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
143 candle_core::bail!("Expected scale to be rank 2");
144 }
145
146 let dev = weight_s.device();
147
148 let weight = weight_s
149 .as_cuda_slice::<F8E4M3>()?
150 .slice(weight_l.start_offset()..);
151 let scale = scale_s
152 .as_cuda_slice::<f32>()?
153 .slice(scale_l.start_offset()..);
154
155 let weight_height = weight_l.dim(0)? as i32;
156 let weight_block_size_x = self.weight_block_size[0] as i32;
157 let weight_width = weight_l.dim(1)? as i32;
158 let weight_block_size_y = self.weight_block_size[1] as i32;
159 let scale_stride = scale_l.stride()[0] as i32;
160 let weight_row_stride = weight_l.stride()[0] as i32;
161
162 let res = match self.out_ty {
163 DType::F32 => {
164 let output = weight_s
165 .device()
166 .alloc_zeros::<f32>(weight_l.shape().elem_count())
167 .w()?;
168 unsafe {
169 ffi::launch_dequant_fp8_blockwise_kernel_f32(
170 (*weight.device_ptr()) as *const _,
171 (*scale.device_ptr()) as *const _,
172 (*output.device_ptr()) as *mut _,
173 weight_height,
174 weight_width,
175 weight_row_stride,
176 scale_stride,
177 weight_block_size_y,
178 weight_block_size_x,
179 *dev.cu_stream(),
180 )
181 };
182 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
183 }
184 DType::F16 => {
185 let output = weight_s
186 .device()
187 .alloc_zeros::<f16>(weight_l.shape().elem_count())
188 .w()?;
189 unsafe {
190 ffi::launch_dequant_fp8_blockwise_kernel_f16(
191 (*weight.device_ptr()) as *const _,
192 (*scale.device_ptr()) as *const _,
193 (*output.device_ptr()) as *mut _,
194 weight_height,
195 weight_width,
196 weight_row_stride,
197 scale_stride,
198 weight_block_size_y,
199 weight_block_size_x,
200 *dev.cu_stream(),
201 )
202 };
203 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
204 }
205 DType::BF16 => {
206 let output = weight_s
207 .device()
208 .alloc_zeros::<bf16>(weight_l.shape().elem_count())
209 .w()?;
210 unsafe {
211 ffi::launch_dequant_fp8_blockwise_kernel_bf16(
212 (*weight.device_ptr()) as *const _,
213 (*scale.device_ptr()) as *const _,
214 (*output.device_ptr()) as *mut _,
215 weight_height,
216 weight_width,
217 weight_row_stride,
218 scale_stride,
219 weight_block_size_y,
220 weight_block_size_x,
221 *dev.cu_stream(),
222 )
223 };
224 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
225 }
226 other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
227 };
228
229 Ok((res, weight_l.shape().clone()))
230 }
231}
232
233pub fn fp8_blockwise_dequantize(
238 weight: &Tensor,
239 inv_scales: &Tensor,
240 weight_block_size: Vec<usize>,
241 out_ty: DType,
242) -> Result<Tensor> {
243 inv_scales.apply_op2_no_bwd(
244 weight,
245 &Fp8BlockwiseDequantize {
246 weight_block_size,
247 out_ty,
248 },
249 )
250}
251
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    /// CPU dequant, F32 output: a 5x5 weight of ones with 2x2 blocks uses a
    /// 3x3 scale grid (ceil(5/2) = 3), so every output element equals the
    /// scale of the block it belongs to.
    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        // Scales 0..9 laid out row-major over the 3x3 block grid.
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

        let res = dequant.to_vec2::<f32>()?;
        assert_eq!(
            res,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    /// CUDA dequant (F32 output) must match the CPU reference exactly.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };

        // Cross-check device agreement and the known expected values.
        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    /// Same as `test_fp8_blockwise_dequant` but with a BF16 output dtype.
    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::BF16)?;

        let res = dequant.to_vec2::<bf16>()?;
        // These small integers are exactly representable in bf16, so exact
        // equality is safe.
        assert_eq!(
            res,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    /// CUDA dequant (BF16 output) must match the CPU reference exactly.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    /// End-to-end smoke test: fetch a real FP8 checkpoint from the HF hub,
    /// dequantize it to BF16, and run a GEMM through `candle_nn::Linear`.
    /// Only the output shape is asserted. Requires network access.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let api = ApiBuilder::new().with_progress(true).build().unwrap();
        let api = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let filename = api.get("test_fp8.safetensors").unwrap();
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        // 7168/128 = 56 row blocks, 2048/128 = 16 column blocks.
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        let truth = {
            let weight_dq =
                ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

            let lin_dq = Linear::new(weight_dq, None);
            lin_dq.forward(&xs)?
        };

        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}