use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
4
/// Custom op that dequantizes an FP8 (E4M3) weight matrix using one `f32`
/// scale per `[rows, cols]` block of the weight.
struct Fp8BlockwiseDequantize {
    // Block shape as `[block_rows, block_cols]`; validated to have length 2
    // before use.
    weight_block_size: Vec<usize>,
    // Output dtype of the dequantized tensor; F32, F16 and BF16 are supported.
    out_ty: DType,
}
9
10impl Fp8BlockwiseDequantize {
11 fn dispatch_dequant_blockwise<T: WithDType>(
12 &self,
13 weight: &[F8E4M3],
14 scale: &[f32],
15 weight_l: &candle_core::Layout,
16 scale_l: &candle_core::Layout,
17 ) -> candle_core::Result<Vec<T>> {
18 let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
19 let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);
20
21 let res = vec![T::zero(); weight.len()];
22
23 (0..grid_y).into_par_iter().for_each(|y| {
24 (0..grid_x).into_par_iter().for_each(|x| {
25 let res_ptr = res.as_ptr() as *mut T;
26
27 let scale = scale[y * scale_l.stride()[0] + x];
28
29 let start_y = y * self.weight_block_size[0];
30 let end_y = start_y + self.weight_block_size[0];
31
32 let start_x = x * self.weight_block_size[1];
33 let end_x = start_x + self.weight_block_size[1];
34
35 for weight_y in start_y..end_y {
36 if weight_y >= weight_l.dims()[0] {
37 break;
38 }
39
40 let row_offset = weight_y * weight_l.stride()[0];
41 for weight_x in start_x..end_x {
42 if weight_x >= weight_l.dims()[1] {
43 break;
44 }
45
46 let weight_pos = row_offset + weight_x;
47
48 unsafe {
50 *res_ptr.wrapping_add(weight_pos) =
51 T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
52 }
53 }
54 }
55 });
56 });
57
58 Ok(res)
59 }
60}
61
62impl CustomOp2 for Fp8BlockwiseDequantize {
63 fn name(&self) -> &'static str {
64 "fp8-blockwise-dequantize"
65 }
66
67 fn cpu_fwd(
68 &self,
69 scale_s: &candle_core::CpuStorage,
70 scale_l: &candle_core::Layout,
71 weight_s: &candle_core::CpuStorage,
72 weight_l: &candle_core::Layout,
73 ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
74 let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
75 candle_core::bail!("Expected F8E4M3 weight!");
76 };
77 let candle_core::CpuStorage::F32(scale) = scale_s else {
78 candle_core::bail!("Expected F8E4M3 weight!");
79 };
80 if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
81 candle_core::bail!("Expected weight to have start offset 0, continuous");
82 }
83 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
84 candle_core::bail!("Expected scales to have start offset 0, continuous");
85 }
86 if weight_l.dims().len() != 2 {
87 candle_core::bail!("Expected weight to be rank 2");
88 }
89 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
90 candle_core::bail!("Expected scale to be rank 2");
91 }
92
93 match self.out_ty {
94 DType::F32 => Ok((
95 CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
96 weight_l.shape().clone(),
97 )),
98 DType::BF16 => Ok((
99 CpuStorage::BF16(
100 self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
101 ),
102 weight_l.shape().clone(),
103 )),
104 DType::F16 => Ok((
105 CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
106 weight_l.shape().clone(),
107 )),
108 other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
109 }
110 }
111
112 #[cfg(feature = "cuda")]
113 fn cuda_fwd(
114 &self,
115 scale_s: &candle_core::CudaStorage,
116 scale_l: &candle_core::Layout,
117 weight_s: &candle_core::CudaStorage,
118 weight_l: &candle_core::Layout,
119 ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
120 use candle_core::{
121 backend::BackendStorage,
122 cuda::{cudarc::driver::DevicePtr, WrapErr},
123 CudaStorage,
124 };
125 use half::{bf16, f16};
126
127 use crate::blockwise_fp8::ffi;
128
129 if !ffi::HAVE_BLOCKWISE_DEQUANT_KERNELS {
130 candle_core::bail!("Do not have blockwise FP8 dequant kernels.");
131 }
132
133 if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
134 candle_core::bail!("Expected weight to have start offset 0, continuous");
135 }
136 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
137 candle_core::bail!("Expected scales to have start offset 0, continuous");
138 }
139 if weight_l.dims().len() != 2 {
140 candle_core::bail!("Expected weight to be rank 2");
141 }
142 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
143 candle_core::bail!("Expected scale to be rank 2");
144 }
145
146 let dev = weight_s.device();
147
148 let weight = weight_s
149 .as_cuda_slice::<F8E4M3>()?
150 .slice(weight_l.start_offset()..);
151 let scale = scale_s
152 .as_cuda_slice::<f32>()?
153 .slice(scale_l.start_offset()..);
154
155 let weight_height = weight_l.dim(0)? as i32;
156 let weight_block_size_x = self.weight_block_size[0] as i32;
157 let weight_width = weight_l.dim(1)? as i32;
158 let weight_block_size_y = self.weight_block_size[1] as i32;
159 let scale_stride = scale_l.stride()[0] as i32;
160 let weight_row_stride = weight_l.stride()[0] as i32;
161
162 let res = match self.out_ty {
163 DType::F32 => {
164 let output = weight_s
165 .device()
166 .alloc_zeros::<f32>(weight_l.shape().elem_count())
167 .w()?;
168 unsafe {
169 ffi::launch_dequant_fp8_blockwise_kernel_f32(
170 (*weight.device_ptr()) as *const _,
171 (*scale.device_ptr()) as *const _,
172 (*output.device_ptr()) as *mut _,
173 weight_height,
174 weight_width,
175 weight_row_stride,
176 scale_stride,
177 weight_block_size_y,
178 weight_block_size_x,
179 *dev.cu_stream(),
180 )
181 };
182 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
183 }
184 DType::F16 => {
185 let output = weight_s
186 .device()
187 .alloc_zeros::<f16>(weight_l.shape().elem_count())
188 .w()?;
189 unsafe {
190 ffi::launch_dequant_fp8_blockwise_kernel_f16(
191 (*weight.device_ptr()) as *const _,
192 (*scale.device_ptr()) as *const _,
193 (*output.device_ptr()) as *mut _,
194 weight_height,
195 weight_width,
196 weight_row_stride,
197 scale_stride,
198 weight_block_size_y,
199 weight_block_size_x,
200 *dev.cu_stream(),
201 )
202 };
203 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
204 }
205 DType::BF16 => {
206 let output = weight_s
207 .device()
208 .alloc_zeros::<bf16>(weight_l.shape().elem_count())
209 .w()?;
210 unsafe {
211 ffi::launch_dequant_fp8_blockwise_kernel_bf16(
212 (*weight.device_ptr()) as *const _,
213 (*scale.device_ptr()) as *const _,
214 (*output.device_ptr()) as *mut _,
215 weight_height,
216 weight_width,
217 weight_row_stride,
218 scale_stride,
219 weight_block_size_y,
220 weight_block_size_x,
221 *dev.cu_stream(),
222 )
223 };
224 CudaStorage::wrap_cuda_slice(output, weight_s.device().clone())
225 }
226 other => candle_core::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
227 };
228
229 Ok((res, weight_l.shape().clone()))
230 }
231
232 #[cfg(feature = "metal")]
233 fn metal_fwd(
234 &self,
235 scale_s: &candle_core::MetalStorage,
236 scale_l: &candle_core::Layout,
237 weight_s: &candle_core::MetalStorage,
238 weight_l: &candle_core::Layout,
239 ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
240 use candle_core::backend::BackendStorage;
241
242 if weight_l.start_offset() != 0
243 || !weight_l.is_contiguous()
244 || weight_s.dtype() != DType::F8E4M3
245 {
246 candle_core::bail!("Expected f8e4m3 weight to have start offset 0, continuous");
247 }
248 if scale_l.start_offset() != 0 || !scale_l.is_contiguous() || scale_s.dtype() != DType::F32
249 {
250 candle_core::bail!("Expected f32 scales to have start offset 0, continuous");
251 }
252 if weight_l.dims().len() != 2 {
253 candle_core::bail!("Expected weight to be rank 2");
254 }
255 if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
256 candle_core::bail!("Expected scale to be rank 2");
257 }
258
259 let command_buffer = weight_s.device().command_buffer()?;
260 command_buffer.set_label("dequant-blockwise-fp8");
261
262 let device = weight_s.device();
263
264 let out_shape = weight_l.shape().clone();
265
266 let output = device.new_buffer(
267 out_shape.elem_count(),
268 weight_s.dtype(),
269 "dequant-blockwise-fp8",
270 )?;
271
272 let weight_height = weight_l.dim(0)? as u32;
273 let weight_block_size_x = self.weight_block_size[0] as u32;
274 let weight_width = weight_l.dim(1)? as u32;
275 let weight_block_size_y = self.weight_block_size[1] as u32;
276 let scale_stride = scale_l.stride()[0] as u32;
277 let weight_row_stride = weight_l.stride()[0] as u32;
278
279 crate::metal_kernels::call_dequant_blockwise_fp8(
280 device.device(),
281 &command_buffer,
282 &crate::metal_kernels::Kernels::new(),
283 self.out_ty,
284 weight_s.buffer(),
285 scale_s.buffer(),
286 &output,
287 weight_height,
288 weight_width,
289 weight_row_stride,
290 scale_stride,
291 weight_block_size_y,
292 weight_block_size_x,
293 )
294 .map_err(candle_core::Error::wrap)?;
295
296 let newstorage = candle_core::MetalStorage::new(
297 output,
298 device.clone(),
299 out_shape.elem_count(),
300 self.out_ty,
301 );
302 Ok((newstorage, out_shape))
303 }
304}
305
306pub fn fp8_blockwise_dequantize(
311 weight: &Tensor,
312 inv_scales: &Tensor,
313 weight_block_size: Vec<usize>,
314 out_ty: DType,
315) -> Result<Tensor> {
316 inv_scales.apply_op2_no_bwd(
317 weight,
318 &Fp8BlockwiseDequantize {
319 weight_block_size,
320 out_ty,
321 },
322 )
323}
324
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};
    use half::bf16;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

    use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

    // 5x5 weight of ones, 2x2 blocks, 3x3 scales 0..9: every output element
    // equals the scale of the block it falls in, including the clamped
    // right/bottom edge blocks.
    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

        let res = dequant.to_vec2::<f32>()?;
        assert_eq!(
            res,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    // Same fixture as the CPU test: the CUDA kernel must match the CPU result
    // exactly (F32 output).
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant =
                ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

            dequant.to_vec2::<f32>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }

    // Same fixture, BF16 output path on the CPU. All expected values are
    // exactly representable in bf16, so equality is exact.
    #[test]
    fn test_fp8_blockwise_dequant_bf16() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::BF16)?;

        let res = dequant.to_vec2::<bf16>()?;
        assert_eq!(
            res,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    // CUDA BF16 path against the CPU BF16 truth.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_blockwise_dequant_cuda_bf16() -> Result<()> {
        let truth = {
            let dev = &Device::Cpu;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };
        let test = {
            let dev = &Device::new_cuda(0)?;
            let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
            let weight_block_size = vec![2, 2];
            let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

            let dequant = ops::fp8_blockwise_dequantize(
                &weight,
                &inv_scales,
                weight_block_size,
                DType::BF16,
            )?;

            dequant.to_vec2::<bf16>()?
        };

        assert_eq!(test, truth);
        assert_eq!(
            test,
            vec![
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(0.),
                    bf16::from_f32(0.),
                    bf16::from_f32(1.),
                    bf16::from_f32(1.),
                    bf16::from_f32(2.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(3.),
                    bf16::from_f32(3.),
                    bf16::from_f32(4.),
                    bf16::from_f32(4.),
                    bf16::from_f32(5.)
                ],
                vec![
                    bf16::from_f32(6.),
                    bf16::from_f32(6.),
                    bf16::from_f32(7.),
                    bf16::from_f32(7.),
                    bf16::from_f32(8.)
                ],
            ]
        );

        Ok(())
    }

    // End-to-end smoke test: downloads a real FP8 checkpoint from the HF Hub
    // (requires network), dequantizes it and runs a GEMM through the
    // dequantized linear layer, checking only the output shape.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_blockwise_fp8_gemm() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let api = ApiBuilder::new().with_progress(true).build().unwrap();
        let api = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let filename = api.get("test_fp8.safetensors").unwrap();
        let vb = unsafe { MmapedSafetensors::new(filename)? };

        let weight = vb.load("weight", &dev, None)?;
        assert_eq!((7168, 2048), weight.dims2()?);
        assert_eq!(DType::F8E4M3, weight.dtype());

        let scale = vb.load("scale", &dev, None)?;
        assert_eq!((56, 16), scale.dims2()?);
        assert_eq!(DType::F32, scale.dtype());

        let weight_block_size = vec![128, 128];

        let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

        let truth = {
            let weight_dq =
                ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

            let lin_dq = Linear::new(weight_dq, None);
            lin_dq.forward(&xs)?
        };

        assert_eq!((32, 7168), truth.dims2()?);

        Ok(())
    }
}