use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VECTOR_SIZE;

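/// Custom op that dequantizes FP8 E4M3 weights using one `f32` scale per
/// `VECTOR_SIZE`-element vector of the weight tensor.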
struct Fp8VectorDequantize {
    out_ty: DType,
}

impl Fp8VectorDequantize {
    fn dispatch_dequant_vector<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        _weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let num_elements = weight.len();
        let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

        if scale.len() != num_vectors {
            candle_core::bail!(
                "Scale length {} doesn't match expected number of vectors {}",
                scale.len(),
                num_vectors
            );
        }

        let res = vec![T::zero(); num_elements];

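        // SAFETY: each parallel task writes only its own disjoint range
        // `vector_start..vector_end` of `res`, so the raw-pointer writes
        // below never alias across threads.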
        (0..num_vectors).into_par_iter().for_each(|vector_idx| {
            let res_ptr = res.as_ptr() as *mut T;
            let vector_scale = scale[vector_idx * scale_l.stride()[0]];
            let vector_start = vector_idx * VECTOR_SIZE;
            let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

            for (idx, &weight_val) in weight[vector_start..vector_end].iter().enumerate() {
                let global_idx = vector_start + idx;
                unsafe {
                    *res_ptr.wrapping_add(global_idx) =
                        T::from_f64((weight_val.to_f32() * vector_scale) as f64);
                }
            }
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8VectorDequantize {
    fn name(&self) -> &'static str {
        "fp8-vector-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if !ffi::HAVE_VECTOR_DEQUANT_KERNELS {
            candle_core::bail!("Do not have vector FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        let dev = weight_s.device();
        let num_elements = weight_l.shape().elem_count();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

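        // Allocate the output buffer on-device, then launch the
        // dtype-specific dequant kernel on this device's CUDA stream.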
        let res = match self.out_ty {
            DType::F32 => {
                let output = dev.alloc_zeros::<f32>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::F16 => {
                let output = dev.alloc_zeros::<f16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::BF16 => {
                let output = dev.alloc_zeros::<bf16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        _scale_s: &candle_core::MetalStorage,
        _scale_l: &candle_core::Layout,
        _weight_s: &candle_core::MetalStorage,
        _weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        candle_core::bail!("FP8 vector dequantization not yet implemented for Metal");
    }
}

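/// Dequantize FP8 E4M3 `weight` to `out_ty` using one `f32` scale per
/// `VECTOR_SIZE`-element vector, as produced by [`fp8_vector_quantize`].
/// The output has `weight`'s shape.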
pub fn fp8_vector_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(weight, &Fp8VectorDequantize { out_ty })
}

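/// Quantize a flat slice to FP8 E4M3 on the CPU, producing one `f32` scale
/// per `VECTOR_SIZE`-element vector.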
fn cpu_vector_quantize<T: WithDType>(
    input: &[T],
    num_elements: usize,
) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let weight = vec![F8E4M3::from_f32(0.0); num_elements];
    let scale = vec![0f32; num_vectors];

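    // SAFETY: task `vector_idx` writes only `weight[vector_start..vector_end]`
    // and `scale[vector_idx]`, so the raw-pointer writes below are disjoint
    // across parallel tasks.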
    (0..num_vectors).into_par_iter().for_each(|vector_idx| {
        let weight_ptr = weight.as_ptr() as *mut F8E4M3;
        let scale_ptr = scale.as_ptr() as *mut f32;

        let vector_start = vector_idx * VECTOR_SIZE;
        let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

        let mut max_abs = 0f32;
        for &input_val in &input[vector_start..vector_end] {
            let val = input_val.to_f64() as f32;
            let abs_val = val.abs();
            if abs_val > max_abs {
                max_abs = abs_val;
            }
        }

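        // 448 is the largest finite value representable in F8E4M3, so this
        // scale maps the vector's max magnitude onto the full FP8 range.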
        let vector_scale = if max_abs > 0.0 {
            max_abs / 448.0
        } else {
            1e-12
        };

        unsafe {
            *scale_ptr.wrapping_add(vector_idx) = vector_scale;
        }

        for (idx, &input_val) in input[vector_start..vector_end].iter().enumerate() {
            let global_idx = vector_start + idx;
            let val = input_val.to_f64() as f32;
            // Clamp to the representable F8E4M3 range before conversion.
            let scaled_val = (val / vector_scale).clamp(-448.0, 448.0);

            unsafe {
                *weight_ptr.wrapping_add(global_idx) = F8E4M3::from_f32(scaled_val);
            }
        }
    });

    Ok((weight, scale))
}

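/// CPU path of [`fp8_vector_quantize`]: reads the tensor out as a flat `Vec`,
/// quantizes it per vector, and rebuilds the weight and scale tensors.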
fn cpu_fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let (weight_data, scale_data) = match input.dtype() {
        DType::F32 => {
            let data = input.to_vec1::<f32>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::F16 => {
            let data = input.to_vec1::<half::f16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::BF16 => {
            let data = input.to_vec1::<half::bf16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        other => candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}"),
    };

    let weight = Tensor::from_vec(weight_data, input.shape(), input.device())?;
    let scale = Tensor::from_vec(scale_data, num_vectors, input.device())?;

    Ok((weight, scale))
}

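/// Quantize `input` to FP8 E4M3, returning `(weights, scales)` with one `f32`
/// scale per `VECTOR_SIZE`-element vector. The element count must be a
/// multiple of `VECTOR_SIZE`. CPU tensors use the parallel CPU path; CUDA
/// tensors dispatch to the FFI kernels.
///
/// Illustrative round trip (not compiled as a doctest):
///
/// ```ignore
/// let input = Tensor::randn(0f32, 1f32, 256, &Device::Cpu)?;
/// let (weights, scales) = fp8_vector_quantize(&input)?;
/// let restored = fp8_vector_dequantize(&weights, &scales, DType::F32)?;
/// ```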
pub fn fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    if num_elements % VECTOR_SIZE != 0 {
        candle_core::bail!(
            "Tensor size {} must be divisible by {} for vector FP8 quantization",
            num_elements,
            VECTOR_SIZE
        );
    }

    if matches!(input.device(), candle_core::Device::Cpu) {
        return cpu_fp8_vector_quantize(input);
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if matches!(input.device(), Device::Cuda(_)) {
            if !ffi::HAVE_VECTOR_QUANT_KERNELS {
                candle_core::bail!("Do not have vector FP8 quant kernels.");
            }

            let input_l = input.layout();
            if input_l.start_offset() != 0 || !input_l.is_contiguous() {
                candle_core::bail!("Expected input to have start offset 0 and be contiguous");
            }

            let dev = match input.device() {
                Device::Cuda(dev) => dev,
                _ => unreachable!(),
            };

            let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

            let weight_output = dev.alloc_zeros::<F8E4M3>(num_elements)?;
            let scale_output = dev.alloc_zeros::<f32>(num_vectors)?;

            let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
            let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

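            // Launch the dtype-specific quant kernel, which fills the FP8
            // weight buffer and writes one scale per vector.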
            match input.dtype() {
                DType::F32 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f32(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::F16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::BF16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_bf16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                other => {
                    candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}")
                }
            }

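            // Release the slice guards before the buffers are moved into
            // `CudaStorage` below.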
            drop(_weight_guard);
            drop(_scale_guard);

            let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
            let weight = Tensor::from((Storage::Cuda(weight_storage), input.shape().clone()));

            let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
            let scale = Tensor::from((
                Storage::Cuda(scale_storage),
                candle_core::Shape::from_dims(&[num_vectors]),
            ));

            Ok((weight, scale))
        } else {
            candle_core::bail!("Expected CUDA device.");
        }
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 vector quantization on non-CPU devices requires the `cuda` feature");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn test_fp8_vector_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        // 256 elements = 2 vectors of VECTOR_SIZE (128) elements each.
        let num_elements = 256;
        let weight = Tensor::ones(num_elements, DType::F8E4M3, dev)?;
        let scales = Tensor::new(&[2.0f32, 3.0f32], dev)?;

        let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
        let res = dequant.to_vec1::<f32>()?;

        // The first vector uses scale 2.0, the second scale 3.0.
        for &val in &res[0..128] {
            assert_eq!(val, 2.0);
        }
        for &val in &res[128..256] {
            assert_eq!(val, 3.0);
        }

        Ok(())
    }

    #[test]
    fn test_fp8_vector_quant_cpu() -> Result<()> {
        let dev = &Device::Cpu;

        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        let (quantized, scales) = fp8_vector_quantize(&input)?;

        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2);

        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        assert!(max_error < 0.27, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        let (quantized, scales) = fp8_vector_quantize(&input)?;

        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2);

        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        assert!(max_error < 0.24, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_cpu_cuda_equivalence() -> Result<()> {
        let cpu_dev = &Device::Cpu;
        let cuda_dev = &Device::new_cuda(0)?;

        let input_data: Vec<f32> = (0..256).map(|i| ((i as f32) - 128.0) / 10.0).collect();
        let cpu_input = Tensor::from_vec(input_data.clone(), 256, cpu_dev)?;
        let cuda_input = Tensor::from_vec(input_data, 256, cuda_dev)?;

        let (cpu_quantized, cpu_scales) = fp8_vector_quantize(&cpu_input)?;
        let (cuda_quantized, cuda_scales) = fp8_vector_quantize(&cuda_input)?;

        let cuda_quantized_cpu = cuda_quantized.to_device(cpu_dev)?;
        let cuda_scales_cpu = cuda_scales.to_device(cpu_dev)?;

        let cpu_quant_vec = cpu_quantized.to_vec1::<F8E4M3>()?;
        let cuda_quant_vec = cuda_quantized_cpu.to_vec1::<F8E4M3>()?;

        assert_eq!(cpu_quant_vec.len(), cuda_quant_vec.len());

        let mut num_differences = 0;
        for (i, (cpu_val, cuda_val)) in cpu_quant_vec.iter().zip(cuda_quant_vec.iter()).enumerate()
        {
            if cpu_val.to_f32() != cuda_val.to_f32() {
                let diff = (cpu_val.to_f32() - cuda_val.to_f32()).abs();
                if diff > 1e-6 {
                    num_differences += 1;
                    if num_differences < 10 {
                        println!(
                            "Difference at index {}: CPU={}, CUDA={}, diff={}",
                            i,
                            cpu_val.to_f32(),
                            cuda_val.to_f32(),
                            diff
                        );
                    }
                }
            }
        }

        assert!(
            num_differences < 5,
            "Too many differences between CPU and CUDA quantization: {}",
            num_differences
        );

        let cpu_scales_vec = cpu_scales.to_vec1::<f32>()?;
        let cuda_scales_vec = cuda_scales_cpu.to_vec1::<f32>()?;

        assert_eq!(cpu_scales_vec.len(), cuda_scales_vec.len());

        for (i, (cpu_scale, cuda_scale)) in cpu_scales_vec
            .iter()
            .zip(cuda_scales_vec.iter())
            .enumerate()
        {
            let scale_diff = (cpu_scale - cuda_scale).abs();
            assert!(
                scale_diff < 1e-6,
                "Scale difference at index {}: CPU={}, CUDA={}, diff={}",
                i,
                cpu_scale,
                cuda_scale,
                scale_diff
            );
        }

        let cpu_dequant = fp8_vector_dequantize(&cpu_quantized, &cpu_scales, DType::F32)?;
        let cuda_dequant =
            fp8_vector_dequantize(&cuda_quantized_cpu, &cuda_scales_cpu, DType::F32)?;

        let cpu_dequant_vec = cpu_dequant.to_vec1::<f32>()?;
        let cuda_dequant_vec = cuda_dequant.to_vec1::<f32>()?;

        let mut max_dequant_diff = 0f32;
        for (cpu_val, cuda_val) in cpu_dequant_vec.iter().zip(cuda_dequant_vec.iter()) {
            let diff = (cpu_val - cuda_val).abs();
            max_dequant_diff = max_dequant_diff.max(diff);
        }

        assert!(
            max_dequant_diff < 1e-5,
            "Max dequantization difference too large: {}",
            max_dequant_diff
        );

        Ok(())
    }
}