#[cfg(feature = "cuda")]
use candle_core::from_storage_no_op;
use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VECTOR_SIZE;

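/// `CustomOp2` that dequantizes FP8 (E4M3) weights using one `f32` scale per
/// `VECTOR_SIZE`-element block, producing output of dtype `out_ty`.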
struct Fp8VectorDequantize {
    out_ty: DType,
}

impl Fp8VectorDequantize {
    fn dispatch_dequant_vector<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        _weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let num_elements = weight.len();
        let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

        if scale.len() != num_vectors {
            candle_core::bail!(
                "Scale length {} doesn't match expected number of vectors {}",
                scale.len(),
                num_vectors
            );
        }

        let res = vec![T::zero(); num_elements];

        (0..num_vectors).into_par_iter().for_each(|vector_idx| {
            // SAFETY: each parallel task writes a disjoint VECTOR_SIZE-wide
            // slice of `res`, so the raw-pointer writes below never alias.
            let res_ptr = res.as_ptr() as *mut T;
            let vector_scale = scale[vector_idx * scale_l.stride()[0]];
            let vector_start = vector_idx * VECTOR_SIZE;
            let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

            for (idx, &weight_val) in weight[vector_start..vector_end].iter().enumerate() {
                let global_idx = vector_start + idx;
                unsafe {
                    *res_ptr.wrapping_add(global_idx) =
                        T::from_f64((weight_val.to_f32() * vector_scale) as f64);
                }
            }
        });

        Ok(res)
    }
}

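// Note: `fp8_vector_dequantize` applies this op as
// `inv_scales.apply_op2_no_bwd(weight, ...)`, so in each `*_fwd` below the
// first storage/layout pair is the scale and the second is the weight.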
impl CustomOp2 for Fp8VectorDequantize {
    fn name(&self) -> &'static str {
        "fp8-vector-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if !ffi::HAVE_VECTOR_DEQUANT_KERNELS {
            candle_core::bail!("Do not have vector FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        let dev = weight_s.device();
        let num_elements = weight_l.shape().elem_count();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

        let res = match self.out_ty {
            DType::F32 => {
                let output = dev.alloc_zeros::<f32>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::F16 => {
                let output = dev.alloc_zeros::<f16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::BF16 => {
                let output = dev.alloc_zeros::<bf16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        _scale_s: &candle_core::MetalStorage,
        _scale_l: &candle_core::Layout,
        _weight_s: &candle_core::MetalStorage,
        _weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        candle_core::bail!("FP8 vector dequantization not yet implemented for Metal");
    }
}

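/// FP8 vector dequantization: converts E4M3 `weight` to `out_ty`, multiplying
/// each `VECTOR_SIZE`-element block by its corresponding `f32` entry in
/// `inv_scales`.
///
/// A minimal round-trip sketch (assuming a 1-D CPU tensor whose element count
/// is a multiple of `VECTOR_SIZE`; illustrative, not a doctest):
///
/// ```ignore
/// let input = Tensor::randn(0f32, 1f32, 256, &Device::Cpu)?;
/// let (weight, scales) = fp8_vector_quantize(&input)?;
/// let restored = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
/// assert_eq!(restored.shape(), input.shape());
/// ```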
pub fn fp8_vector_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(weight, &Fp8VectorDequantize { out_ty })
}

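// CPU reference implementation of per-vector quantization. Each
// VECTOR_SIZE-wide block gets `scale = max_abs / 448.0` (448 is the largest
// finite E4M3 value, so the block maps onto the full FP8 range), and each
// element is stored as `clamp(x / scale, -448.0, 448.0)`.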
fn cpu_vector_quantize<T: WithDType>(
    input: &[T],
    num_elements: usize,
) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let weight = vec![F8E4M3::from_f32(0.0); num_elements];
    let scale = vec![0f32; num_vectors];

    (0..num_vectors).into_par_iter().for_each(|vector_idx| {
        // SAFETY: each parallel task writes a disjoint block of `weight` and a
        // distinct element of `scale`, so these raw-pointer writes never alias.
        let weight_ptr = weight.as_ptr() as *mut F8E4M3;
        let scale_ptr = scale.as_ptr() as *mut f32;

        let vector_start = vector_idx * VECTOR_SIZE;
        let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

        // Find the maximum absolute value in this vector.
        let mut max_abs = 0f32;
        for &input_val in &input[vector_start..vector_end] {
            let val = input_val.to_f64() as f32;
            let abs_val = val.abs();
            if abs_val > max_abs {
                max_abs = abs_val;
            }
        }

        // Scale so the block spans the full E4M3 range (max finite value 448);
        // fall back to a tiny scale for all-zero blocks to avoid dividing by zero.
        let vector_scale = if max_abs > 0.0 {
            max_abs / 448.0
        } else {
            1e-12
        };

        unsafe {
            *scale_ptr.wrapping_add(vector_idx) = vector_scale;
        }

        // Quantize each element of the vector.
        for (idx, &input_val) in input[vector_start..vector_end].iter().enumerate() {
            let global_idx = vector_start + idx;
            let val = input_val.to_f64() as f32;
            let scaled_val = (val / vector_scale).clamp(-448.0, 448.0);

            unsafe {
                *weight_ptr.wrapping_add(global_idx) = F8E4M3::from_f32(scaled_val);
            }
        }
    });

    Ok((weight, scale))
}

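// Dispatches the CPU quantization above on the input dtype and wraps the raw
// buffers back up as candle tensors.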
fn cpu_fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let (weight_data, scale_data) = match input.dtype() {
        DType::F32 => {
            let data = input.to_vec1::<f32>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::F16 => {
            let data = input.to_vec1::<half::f16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::BF16 => {
            let data = input.to_vec1::<half::bf16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        other => candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}"),
    };

    let weight = Tensor::from_vec(weight_data, input.shape(), input.device())?;
    let scale = Tensor::from_vec(scale_data, num_vectors, input.device())?;

    Ok((weight, scale))
}

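/// FP8 vector quantization: converts `input` to E4M3 weights plus one `f32`
/// scale per `VECTOR_SIZE`-element block, where `scale = max_abs / 448.0`
/// (448 is the largest finite E4M3 value) and each element is stored as
/// `clamp(x / scale, -448.0, 448.0)`.
///
/// The element count must be a multiple of `VECTOR_SIZE`. CPU tensors take the
/// Rust path above; CUDA tensors dispatch to the FFI kernels.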
pub fn fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    if num_elements % VECTOR_SIZE != 0 {
        candle_core::bail!(
            "Tensor size {} must be divisible by {} for vector FP8 quantization",
            num_elements,
            VECTOR_SIZE
        );
    }

    if matches!(input.device(), candle_core::Device::Cpu) {
        return cpu_fp8_vector_quantize(input);
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if matches!(input.device(), Device::Cuda(_)) {
            if !ffi::HAVE_VECTOR_QUANT_KERNELS {
                candle_core::bail!("Do not have vector FP8 quant kernels.");
            }

            let input_l = input.layout();
            if input_l.start_offset() != 0 || !input_l.is_contiguous() {
                candle_core::bail!("Expected input to have start offset 0 and be contiguous");
            }

            let dev = match input.device() {
                Device::Cuda(dev) => dev,
                _ => unreachable!(),
            };

            let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

            let weight_output = dev.alloc_zeros::<F8E4M3>(num_elements)?;
            let scale_output = dev.alloc_zeros::<f32>(num_vectors)?;

            let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
            let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

            match input.dtype() {
                DType::F32 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f32(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::F16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::BF16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_bf16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                other => {
                    candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}")
                }
            }

            // Release the pointer guards before wrapping the slices as storage.
            drop(_weight_guard);
            drop(_scale_guard);

            // Wrap the raw CUDA buffers back into candle tensors without copying.
            let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
            let weight =
                from_storage_no_op(Storage::Cuda(weight_storage), input.shape().clone(), false);

            let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
            let scale = from_storage_no_op(
                Storage::Cuda(scale_storage),
                candle_core::Shape::from_dims(&[num_vectors]),
                false,
            );

            return Ok((weight, scale));
        } else {
            candle_core::bail!("Expected CUDA device.");
        }
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 vector quantization on non-CPU devices requires CUDA feature");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn test_fp8_vector_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        // Two vectors of VECTOR_SIZE = 128 elements each.
        let num_elements = 256;
        let weight = Tensor::ones(num_elements, DType::F8E4M3, dev)?;
        // One scale per vector.
        let scales = Tensor::new(&[2.0f32, 3.0f32], dev)?;
        let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
        let res = dequant.to_vec1::<f32>()?;

        // First vector: 1.0 * 2.0.
        for &val in &res[0..128] {
            assert_eq!(val, 2.0);
        }
        // Second vector: 1.0 * 3.0.
        for &val in &res[128..256] {
            assert_eq!(val, 3.0);
        }

        Ok(())
    }

    #[test]
    fn test_fp8_vector_quant_cpu() -> Result<()> {
        let dev = &Device::Cpu;

        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        let (quantized, scales) = fp8_vector_quantize(&input)?;

        assert_eq!(quantized.shape(), input.shape());
        // 256 elements / VECTOR_SIZE (128) = 2 vectors.
        assert_eq!(scales.dims1()?, 2);

        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        assert!(max_error < 0.25, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        let (quantized, scales) = fp8_vector_quantize(&input)?;

        assert_eq!(quantized.shape(), input.shape());
        // 256 elements / VECTOR_SIZE (128) = 2 vectors.
        assert_eq!(scales.dims1()?, 2);

        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        assert!(max_error < 0.24, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_cpu_cuda_equivalence() -> Result<()> {
        let cpu_dev = &Device::Cpu;
        let cuda_dev = &Device::new_cuda(0)?;

        // Deterministic input covering negative and positive values.
        let input_data: Vec<f32> = (0..256).map(|i| ((i as f32) - 128.0) / 10.0).collect();
        let cpu_input = Tensor::from_vec(input_data.clone(), 256, cpu_dev)?;
        let cuda_input = Tensor::from_vec(input_data, 256, cuda_dev)?;

        let (cpu_quantized, cpu_scales) = fp8_vector_quantize(&cpu_input)?;

        let (cuda_quantized, cuda_scales) = fp8_vector_quantize(&cuda_input)?;

        let cuda_quantized_cpu = cuda_quantized.to_device(cpu_dev)?;
        let cuda_scales_cpu = cuda_scales.to_device(cpu_dev)?;

        let cpu_quant_vec = cpu_quantized.to_vec1::<F8E4M3>()?;
        let cuda_quant_vec = cuda_quantized_cpu.to_vec1::<F8E4M3>()?;

        assert_eq!(cpu_quant_vec.len(), cuda_quant_vec.len());

        // Count elementwise mismatches between the two quantized outputs.
        let mut num_differences = 0;
        for (i, (cpu_val, cuda_val)) in cpu_quant_vec.iter().zip(cuda_quant_vec.iter()).enumerate()
        {
            if cpu_val.to_f32() != cuda_val.to_f32() {
                let diff = (cpu_val.to_f32() - cuda_val.to_f32()).abs();
                if diff > 1e-6 {
                    num_differences += 1;
                    if num_differences < 10 {
                        println!(
                            "Difference at index {}: CPU={}, CUDA={}, diff={}",
                            i,
                            cpu_val.to_f32(),
                            cuda_val.to_f32(),
                            diff
                        );
                    }
                }
            }
        }

        // Allow a handful of rounding differences between the two backends.
        assert!(
            num_differences < 5,
            "Too many differences between CPU and CUDA quantization: {}",
            num_differences
        );

        let cpu_scales_vec = cpu_scales.to_vec1::<f32>()?;
        let cuda_scales_vec = cuda_scales_cpu.to_vec1::<f32>()?;

        assert_eq!(cpu_scales_vec.len(), cuda_scales_vec.len());

        for (i, (cpu_scale, cuda_scale)) in cpu_scales_vec
            .iter()
            .zip(cuda_scales_vec.iter())
            .enumerate()
        {
            let scale_diff = (cpu_scale - cuda_scale).abs();
            assert!(
                scale_diff < 1e-6,
                "Scale difference at index {}: CPU={}, CUDA={}, diff={}",
                i,
                cpu_scale,
                cuda_scale,
                scale_diff
            );
        }

        // Both quantized forms should dequantize to nearly identical values.
        let cpu_dequant = fp8_vector_dequantize(&cpu_quantized, &cpu_scales, DType::F32)?;
        let cuda_dequant =
            fp8_vector_dequantize(&cuda_quantized_cpu, &cuda_scales_cpu, DType::F32)?;

        let cpu_dequant_vec = cpu_dequant.to_vec1::<f32>()?;
        let cuda_dequant_vec = cuda_dequant.to_vec1::<f32>()?;

        let mut max_dequant_diff = 0f32;
        for (cpu_val, cuda_val) in cpu_dequant_vec.iter().zip(cuda_dequant_vec.iter()) {
            let diff = (cpu_val - cuda_val).abs();
            max_dequant_diff = max_dequant_diff.max(diff);
        }

        assert!(
            max_dequant_diff < 1e-5,
            "Max dequantization difference too large: {}",
            max_dequant_diff
        );

        Ok(())
    }
}