use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VECTOR_SIZE;

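/// `CustomOp2` that dequantizes FP8 (E4M3) weights to `out_ty`, using one
/// `f32` scale per `VECTOR_SIZE`-element vector, with CPU, CUDA, and Metal
/// implementations.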
struct Fp8VectorDequantize {
    out_ty: DType,
}

impl Fp8VectorDequantize {
    fn dispatch_dequant_vector<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        _weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let num_elements = weight.len();
        let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

        if scale.len() != num_vectors {
            candle_core::bail!(
                "Scale length {} doesn't match expected number of vectors {}",
                scale.len(),
                num_vectors
            );
        }

        let res = vec![T::zero(); num_elements];

        (0..num_vectors).into_par_iter().for_each(|vector_idx| {
            // Each vector writes a disjoint region of `res`, so the raw-pointer
            // writes below never alias across rayon workers.
            let res_ptr = res.as_ptr() as *mut T;
            let vector_scale = scale[vector_idx * scale_l.stride()[0]];
            let vector_start = vector_idx * VECTOR_SIZE;
            let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

            for (idx, &weight_val) in weight[vector_start..vector_end].iter().enumerate() {
                let global_idx = vector_start + idx;
                unsafe {
                    *res_ptr.wrapping_add(global_idx) =
                        T::from_f64((weight_val.to_f32() * vector_scale) as f64);
                }
            }
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8VectorDequantize {
    fn name(&self) -> &'static str {
        "fp8-vector-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if !ffi::HAVE_VECTOR_DEQUANT_KERNELS {
            candle_core::bail!("Do not have vector FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }

        let dev = weight_s.device();
        let num_elements = weight_l.shape().elem_count();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

        let res = match self.out_ty {
            DType::F32 => {
                let output = dev.alloc_zeros::<f32>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::F16 => {
                let output = dev.alloc_zeros::<f16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::BF16 => {
                let output = dev.alloc_zeros::<bf16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        scale_s: &candle_core::MetalStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::MetalStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }

        let device = weight_s.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("fp8-vector-dequant");

        let num_elements = weight_l.shape().elem_count();
        let out_shape = weight_l.shape().clone();

        let output = device.new_buffer(num_elements, self.out_ty, "fp8-vector-dequant-output")?;

        crate::metal_kernels::call_fp8_vector_dequant(
            device.device(),
            &encoder,
            &crate::metal_kernels::Kernels::new(),
            self.out_ty,
            weight_s.buffer(),
            scale_s.buffer(),
            &output,
            num_elements,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage =
            candle_core::MetalStorage::new(output, device.clone(), num_elements, self.out_ty);
        Ok((newstorage, out_shape))
    }
}

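/// Dequantizes FP8 (E4M3) `weight` back to `out_ty`, multiplying each
/// `VECTOR_SIZE`-element vector by its entry in `inv_scales`. Note the operand
/// order: the scales tensor is the first operand of `apply_op2_no_bwd`,
/// matching the `cpu_fwd`/`cuda_fwd`/`metal_fwd` signatures above.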
pub fn fp8_vector_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(weight, &Fp8VectorDequantize { out_ty })
}

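/// Reference CPU quantizer: per `VECTOR_SIZE`-element vector, find the maximum
/// absolute value, derive `scale = max_abs / 448.0` (448 is the largest finite
/// F8E4M3 value), then store `clamp(val / scale, -448.0, 448.0)` as FP8.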
fn cpu_vector_quantize<T: WithDType>(
    input: &[T],
    num_elements: usize,
) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let weight = vec![F8E4M3::from_f32(0.0); num_elements];
    let scale = vec![0f32; num_vectors];

    (0..num_vectors).into_par_iter().for_each(|vector_idx| {
        // As in dequantization, each vector writes a disjoint region, so the
        // raw-pointer writes do not alias across rayon workers.
        let weight_ptr = weight.as_ptr() as *mut F8E4M3;
        let scale_ptr = scale.as_ptr() as *mut f32;

        let vector_start = vector_idx * VECTOR_SIZE;
        let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

        // Find the maximum absolute value in this vector.
        let mut max_abs = 0f32;
        for &input_val in &input[vector_start..vector_end] {
            let val = input_val.to_f64() as f32;
            let abs_val = val.abs();
            if abs_val > max_abs {
                max_abs = abs_val;
            }
        }

        // 448.0 is the largest finite F8E4M3 value; fall back to a tiny
        // non-zero scale for an all-zero vector.
        let vector_scale = if max_abs > 0.0 {
            max_abs / 448.0
        } else {
            1e-12
        };

        unsafe {
            *scale_ptr.wrapping_add(vector_idx) = vector_scale;
        }

        // Quantize each element of the vector.
        for (idx, &input_val) in input[vector_start..vector_end].iter().enumerate() {
            let global_idx = vector_start + idx;
            let val = input_val.to_f64() as f32;
            let scaled_val = (val / vector_scale).clamp(-448.0, 448.0);

            unsafe {
                *weight_ptr.wrapping_add(global_idx) = F8E4M3::from_f32(scaled_val);
            }
        }
    });

    Ok((weight, scale))
}

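/// CPU entry point: extracts the raw values for the supported input dtypes
/// (F32/F16/BF16), runs `cpu_vector_quantize`, and wraps the results back into
/// tensors. Weights keep the input shape; scales come back as a 1-D tensor of
/// length `num_vectors`.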
fn cpu_fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let (weight_data, scale_data) = match input.dtype() {
        DType::F32 => {
            let data = input.to_vec1::<f32>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::F16 => {
            let data = input.to_vec1::<half::f16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::BF16 => {
            let data = input.to_vec1::<half::bf16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        other => candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}"),
    };

    let weight = Tensor::from_vec(weight_data, input.shape(), input.device())?;
    let scale = Tensor::from_vec(scale_data, num_vectors, input.device())?;

    Ok((weight, scale))
}

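/// Quantizes `input` to FP8 (E4M3) with one `f32` scale per
/// `VECTOR_SIZE`-element vector; the element count must be divisible by
/// `VECTOR_SIZE`. CPU tensors take the `cpu_fp8_vector_quantize` path; CUDA
/// tensors use the FFI kernels. A minimal round-trip sketch (assuming
/// `VECTOR_SIZE == 128`, as the tests below imply):
///
/// ```ignore
/// let input = Tensor::randn(0f32, 1f32, 256, &Device::Cpu)?;
/// let (weight, scales) = fp8_vector_quantize(&input)?;
/// let restored = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
/// ```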
pub fn fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    if num_elements % VECTOR_SIZE != 0 {
        candle_core::bail!(
            "Tensor size {} must be divisible by {} for vector FP8 quantization",
            num_elements,
            VECTOR_SIZE
        );
    }

    if matches!(input.device(), candle_core::Device::Cpu) {
        return cpu_fp8_vector_quantize(input);
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if matches!(input.device(), Device::Cuda(_)) {
            if !ffi::HAVE_VECTOR_QUANT_KERNELS {
                candle_core::bail!("Do not have vector FP8 quant kernels.");
            }

            let input_l = input.layout();
            if input_l.start_offset() != 0 || !input_l.is_contiguous() {
                candle_core::bail!("Expected input to have start offset 0, contiguous");
            }

            let dev = match input.device() {
                Device::Cuda(dev) => dev,
                _ => unreachable!(),
            };

            let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

            let weight_output = dev.alloc_zeros::<F8E4M3>(num_elements)?;
            let scale_output = dev.alloc_zeros::<f32>(num_vectors)?;

            let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
            let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

            match input.dtype() {
                DType::F32 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f32(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::F16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::BF16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_bf16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                other => {
                    candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}")
                }
            }

            drop(_weight_guard);
            drop(_scale_guard);

            let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
            let weight = Tensor::from((Storage::Cuda(weight_storage), input.shape().clone()));

            let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
            let scale = Tensor::from((
                Storage::Cuda(scale_storage),
                candle_core::Shape::from_dims(&[num_vectors]),
            ));

            Ok((weight, scale))
        } else {
            candle_core::bail!("Expected CUDA device.");
        }
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 vector quantization on non-CPU devices requires CUDA feature");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn test_fp8_vector_dequant() -> Result<()> {
        let dev = &Device::Cpu;

        // 256 elements = 2 vectors of VECTOR_SIZE (128) elements each.
        let num_elements = 256;
        let weight = Tensor::ones(num_elements, DType::F8E4M3, dev)?;
        let scales = Tensor::new(&[2.0f32, 3.0f32], dev)?;

        let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
        let res = dequant.to_vec1::<f32>()?;

        // The first vector should be scaled by 2.0, the second by 3.0.
        for &val in &res[0..128] {
            assert_eq!(val, 2.0);
        }
        for &val in &res[128..256] {
            assert_eq!(val, 3.0);
        }

        Ok(())
    }

    #[test]
    fn test_fp8_vector_quant_cpu() -> Result<()> {
        let dev = &Device::Cpu;

        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        let (quantized, scales) = fp8_vector_quantize(&input)?;

        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2);

        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        assert!(max_error < 0.27, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        let (quantized, scales) = fp8_vector_quantize(&input)?;

        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2);

        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        assert_eq!(dequantized.shape(), input.shape());

        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        assert!(max_error < 0.24, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_cpu_cuda_equivalence() -> Result<()> {
        let cpu_dev = &Device::Cpu;
        let cuda_dev = &Device::new_cuda(0)?;

        let input_data: Vec<f32> = (0..256).map(|i| ((i as f32) - 128.0) / 10.0).collect();
        let cpu_input = Tensor::from_vec(input_data.clone(), 256, cpu_dev)?;
        let cuda_input = Tensor::from_vec(input_data, 256, cuda_dev)?;

        let (cpu_quantized, cpu_scales) = fp8_vector_quantize(&cpu_input)?;

        let (cuda_quantized, cuda_scales) = fp8_vector_quantize(&cuda_input)?;

        let cuda_quantized_cpu = cuda_quantized.to_device(cpu_dev)?;
        let cuda_scales_cpu = cuda_scales.to_device(cpu_dev)?;

        let cpu_quant_vec = cpu_quantized.to_vec1::<F8E4M3>()?;
        let cuda_quant_vec = cuda_quantized_cpu.to_vec1::<F8E4M3>()?;

        assert_eq!(cpu_quant_vec.len(), cuda_quant_vec.len());

        let mut num_differences = 0;
        for (i, (cpu_val, cuda_val)) in cpu_quant_vec.iter().zip(cuda_quant_vec.iter()).enumerate()
        {
            if cpu_val.to_f32() != cuda_val.to_f32() {
                let diff = (cpu_val.to_f32() - cuda_val.to_f32()).abs();
                if diff > 1e-6 {
                    num_differences += 1;
                    if num_differences < 10 {
                        println!(
                            "Difference at index {}: CPU={}, CUDA={}, diff={}",
                            i,
                            cpu_val.to_f32(),
                            cuda_val.to_f32(),
                            diff
                        );
                    }
                }
            }
        }

        assert!(
            num_differences < 5,
            "Too many differences between CPU and CUDA quantization: {num_differences}"
        );

        let cpu_scales_vec = cpu_scales.to_vec1::<f32>()?;
        let cuda_scales_vec = cuda_scales_cpu.to_vec1::<f32>()?;

        assert_eq!(cpu_scales_vec.len(), cuda_scales_vec.len());

        for (i, (cpu_scale, cuda_scale)) in cpu_scales_vec
            .iter()
            .zip(cuda_scales_vec.iter())
            .enumerate()
        {
            let scale_diff = (cpu_scale - cuda_scale).abs();
            assert!(
                scale_diff < 1e-6,
                "Scale difference at index {i}: CPU={cpu_scale}, CUDA={cuda_scale}, diff={scale_diff}"
            );
        }

        let cpu_dequant = fp8_vector_dequantize(&cpu_quantized, &cpu_scales, DType::F32)?;
        let cuda_dequant =
            fp8_vector_dequantize(&cuda_quantized_cpu, &cuda_scales_cpu, DType::F32)?;

        let cpu_dequant_vec = cpu_dequant.to_vec1::<f32>()?;
        let cuda_dequant_vec = cuda_dequant.to_vec1::<f32>()?;

        let mut max_dequant_diff = 0f32;
        for (cpu_val, cuda_val) in cpu_dequant_vec.iter().zip(cuda_dequant_vec.iter()) {
            let diff = (cpu_val - cuda_val).abs();
            max_dequant_diff = max_dequant_diff.max(diff);
        }

        assert!(
            max_dequant_diff < 1e-5,
            "Max dequantization difference too large: {max_dequant_diff}"
        );

        Ok(())
    }
}