mistralrs_quant/vector_fp8/ops.rs

use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VECTOR_SIZE;

struct Fp8VectorDequantize {
    out_ty: DType,
}

impl Fp8VectorDequantize {
    fn dispatch_dequant_vector<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        _weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let num_elements = weight.len();
        let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

        if scale.len() != num_vectors {
            candle_core::bail!(
                "Scale length {} doesn't match expected number of vectors {}",
                scale.len(),
                num_vectors
            );
        }

        let res = vec![T::zero(); num_elements];

        (0..num_vectors).into_par_iter().for_each(|vector_idx| {
            let res_ptr = res.as_ptr() as *mut T;
            let vector_scale = scale[vector_idx * scale_l.stride()[0]];
            let vector_start = vector_idx * VECTOR_SIZE;
            let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

            for (idx, &weight_val) in weight[vector_start..vector_end].iter().enumerate() {
                let global_idx = vector_start + idx;
                // SAFETY: We know each thread will only update independent values!
                unsafe {
                    *res_ptr.wrapping_add(global_idx) =
                        T::from_f64((weight_val.to_f32() * vector_scale) as f64);
                }
            }
        });

        Ok(res)
    }
}

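// Note: `fp8_vector_dequantize` below calls `inv_scales.apply_op2_no_bwd(weight, ..)`, so the
// first storage/layout pair passed to these callbacks is the scale and the second is the weight.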
impl CustomOp2 for Fp8VectorDequantize {
    fn name(&self) -> &'static str {
        "fp8-vector-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if !ffi::HAVE_VECTOR_DEQUANT_KERNELS {
            candle_core::bail!("Do not have vector FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        let dev = weight_s.device();
        let num_elements = weight_l.shape().elem_count();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

        let res = match self.out_ty {
            DType::F32 => {
                let output = dev.alloc_zeros::<f32>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::F16 => {
                let output = dev.alloc_zeros::<f16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::BF16 => {
                let output = dev.alloc_zeros::<bf16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        _scale_s: &candle_core::MetalStorage,
        _scale_l: &candle_core::Layout,
        _weight_s: &candle_core::MetalStorage,
        _weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        candle_core::bail!("FP8 vector dequantization not yet implemented for Metal");
    }
}

/// FP8 vector dequantize.
/// - Expects weight to be fp8
/// - Expects inv_scales to be f32
/// - weight * inv_scale = dequantized
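///
/// # Example
///
/// A minimal sketch of the intended usage (not compiled as a doctest; assumes `VECTOR_SIZE`
/// is 128 and the `F8E4M3` dtype is available on the CPU backend):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let dev = Device::Cpu;
/// // One vector of 128 FP8 ones, sharing a single f32 inverse scale of 0.5.
/// let weight = Tensor::ones(128, DType::F8E4M3, &dev)?;
/// let inv_scales = Tensor::new(&[0.5f32], &dev)?;
/// let dequant = fp8_vector_dequantize(&weight, &inv_scales, DType::F32)?;
/// assert_eq!(dequant.to_vec1::<f32>()?, vec![0.5f32; 128]);
/// ```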
pub fn fp8_vector_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(weight, &Fp8VectorDequantize { out_ty })
}

/// CPU implementation of vector quantization
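///
/// Each contiguous block of `VECTOR_SIZE` elements gets its own scale, computed as
/// `max_abs / 448.0` (448 is the largest finite F8E4M3 magnitude). For example, a block whose
/// largest absolute value is 8.96 gets a scale of 0.02; every element is then divided by the
/// scale and clamped to [-448, 448] before the FP8 cast.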
fn cpu_vector_quantize<T: WithDType>(
    input: &[T],
    num_elements: usize,
) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let weight = vec![F8E4M3::from_f32(0.0); num_elements];
    let scale = vec![0f32; num_vectors];

    (0..num_vectors).into_par_iter().for_each(|vector_idx| {
        let weight_ptr = weight.as_ptr() as *mut F8E4M3;
        let scale_ptr = scale.as_ptr() as *mut f32;

        let vector_start = vector_idx * VECTOR_SIZE;
        let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

        // Find max absolute value in vector
        let mut max_abs = 0f32;
        for &input_val in &input[vector_start..vector_end] {
            let val = input_val.to_f64() as f32;
            let abs_val = val.abs();
            if abs_val > max_abs {
                max_abs = abs_val;
            }
        }

        // Calculate scale
        let vector_scale = if max_abs > 0.0 {
            max_abs / 448.0
        } else {
            1e-12
        };

        // SAFETY: We know each thread will only update independent values!
        unsafe {
            *scale_ptr.wrapping_add(vector_idx) = vector_scale;
        }

        // Quantize values
        for (idx, &input_val) in input[vector_start..vector_end].iter().enumerate() {
            let global_idx = vector_start + idx;
            let val = input_val.to_f64() as f32;
            let scaled_val = (val / vector_scale).clamp(-448.0, 448.0);

            // SAFETY: We know each thread will only update independent values!
            unsafe {
                *weight_ptr.wrapping_add(global_idx) = F8E4M3::from_f32(scaled_val);
            }
        }
    });

    Ok((weight, scale))
}

/// FP8 vector quantize for CPU
fn cpu_fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let (weight_data, scale_data) = match input.dtype() {
        DType::F32 => {
            let data = input.to_vec1::<f32>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::F16 => {
            let data = input.to_vec1::<half::f16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::BF16 => {
            let data = input.to_vec1::<half::bf16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        other => candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}"),
    };

    // Create tensors from the raw data
    let weight = Tensor::from_vec(weight_data, input.shape(), input.device())?;
    let scale = Tensor::from_vec(scale_data, num_vectors, input.device())?;

    Ok((weight, scale))
}

/// FP8 vector quantize.
/// - Expects input to be f32, f16, or bf16
/// - Returns a tuple of (quantized_weight, scales)
/// - quantized_weight is fp8
/// - scales is f32
/// - Each scale corresponds to a vector of 128 elements
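///
/// # Example
///
/// A minimal round-trip sketch (not compiled as a doctest; assumes a rank-1 CPU tensor whose
/// element count is a multiple of `VECTOR_SIZE`):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let dev = Device::Cpu;
/// let input = Tensor::randn(0f32, 1f32, 256, &dev)?; // 2 vectors of 128 elements
/// let (quantized, scales) = fp8_vector_quantize(&input)?;
/// assert_eq!(scales.dims1()?, 2);
/// // Dequantizing recovers the input up to FP8 E4M3 rounding error.
/// let roundtrip = fp8_vector_dequantize(&quantized, &scales, DType::F32)?;
/// ```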
pub fn fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    // Check that the tensor size is divisible by VECTOR_SIZE (128)
    let num_elements = input.shape().elem_count();
    if num_elements % VECTOR_SIZE != 0 {
        candle_core::bail!(
            "Tensor size {} must be divisible by {} for vector FP8 quantization",
            num_elements,
            VECTOR_SIZE
        );
    }

    // Check if we should use CPU implementation
    if matches!(input.device(), candle_core::Device::Cpu) {
        return cpu_fp8_vector_quantize(input);
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if matches!(input.device(), Device::Cuda(_)) {
            if !ffi::HAVE_VECTOR_QUANT_KERNELS {
                candle_core::bail!("Do not have vector FP8 quant kernels.");
            }

            let input_l = input.layout();
            if input_l.start_offset() != 0 || !input_l.is_contiguous() {
                candle_core::bail!("Expected input to have start offset 0 and be contiguous");
            }

            let dev = match input.device() {
                Device::Cuda(dev) => dev,
                _ => unreachable!(),
            };

            let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

            // Allocate output buffers
            let weight_output = dev.alloc_zeros::<F8E4M3>(num_elements)?;
            let scale_output = dev.alloc_zeros::<f32>(num_vectors)?;

            let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
            let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

            match input.dtype() {
                DType::F32 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f32(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::F16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::BF16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_bf16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                other => {
                    candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}")
                }
            }

            // Drop guards before moving the buffers
            drop(_weight_guard);
            drop(_scale_guard);

            // Create weight tensor by wrapping the CUDA storage
            let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
            let weight = Tensor::from((Storage::Cuda(weight_storage), input.shape().clone()));

            // Create scale tensor
            let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
            let scale = Tensor::from((
                Storage::Cuda(scale_storage),
                candle_core::Shape::from_dims(&[num_vectors]),
            ));

            Ok((weight, scale))
        } else {
            candle_core::bail!("Expected CUDA device.");
        }
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 vector quantization on non-CPU devices requires CUDA feature");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn test_fp8_vector_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let num_elements = 256; // 2 vectors of 128 elements
        let weight = Tensor::ones(num_elements, DType::F8E4M3, dev)?;
        let scales = Tensor::new(&[2.0f32, 3.0f32], dev)?; // 2 scales for 2 vectors

        let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
        let res = dequant.to_vec1::<f32>()?;

        // First 128 elements should be 2.0, next 128 should be 3.0
        for &val in &res[0..128] {
            assert_eq!(val, 2.0);
        }
        for &val in &res[128..256] {
            assert_eq!(val, 3.0);
        }

        Ok(())
    }

    #[test]
    fn test_fp8_vector_quant_cpu() -> Result<()> {
        let dev = &Device::Cpu;

        // Create test input with 256 elements (2 vectors)
        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        // Quantize
        let (quantized, scales) = fp8_vector_quantize(&input)?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2); // 256/128 = 2 vectors

        // Dequantize
        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        // FP8 E4M3 has limited precision, so we expect some error
        assert!(max_error < 0.27, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        // Create test input with 256 elements (2 vectors)
        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        // Quantize
        let (quantized, scales) = fp8_vector_quantize(&input)?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2); // 256/128 = 2 vectors

        // Dequantize
        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        // FP8 E4M3 has limited precision, so we expect some error
        assert!(max_error < 0.24, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_cpu_cuda_equivalence() -> Result<()> {
        let cpu_dev = &Device::Cpu;
        let cuda_dev = &Device::new_cuda(0)?;

        // Create the same input data on both devices
        let input_data: Vec<f32> = (0..256).map(|i| ((i as f32) - 128.0) / 10.0).collect();
        let cpu_input = Tensor::from_vec(input_data.clone(), 256, cpu_dev)?;
        let cuda_input = Tensor::from_vec(input_data, 256, cuda_dev)?;

        // Quantize on CPU
        let (cpu_quantized, cpu_scales) = fp8_vector_quantize(&cpu_input)?;

        // Quantize on CUDA
        let (cuda_quantized, cuda_scales) = fp8_vector_quantize(&cuda_input)?;

        // Move CUDA results to CPU for comparison
        let cuda_quantized_cpu = cuda_quantized.to_device(cpu_dev)?;
        let cuda_scales_cpu = cuda_scales.to_device(cpu_dev)?;

        // Compare quantized weights
        let cpu_quant_vec = cpu_quantized.to_vec1::<F8E4M3>()?;
        let cuda_quant_vec = cuda_quantized_cpu.to_vec1::<F8E4M3>()?;

        assert_eq!(cpu_quant_vec.len(), cuda_quant_vec.len());

        let mut num_differences = 0;
        for (i, (cpu_val, cuda_val)) in cpu_quant_vec.iter().zip(cuda_quant_vec.iter()).enumerate()
        {
            if cpu_val.to_f32() != cuda_val.to_f32() {
                // Allow small differences due to floating point precision
                let diff = (cpu_val.to_f32() - cuda_val.to_f32()).abs();
                if diff > 1e-6 {
                    num_differences += 1;
                    if num_differences < 10 {
                        println!(
                            "Difference at index {}: CPU={}, CUDA={}, diff={}",
                            i,
                            cpu_val.to_f32(),
                            cuda_val.to_f32(),
                            diff
                        );
                    }
                }
            }
        }

        // FP8 quantization should be deterministic, so we expect very few differences
        assert!(
            num_differences < 5,
            "Too many differences between CPU and CUDA quantization: {num_differences}"
        );

        // Compare scales
        let cpu_scales_vec = cpu_scales.to_vec1::<f32>()?;
        let cuda_scales_vec = cuda_scales_cpu.to_vec1::<f32>()?;

        assert_eq!(cpu_scales_vec.len(), cuda_scales_vec.len());

        for (i, (cpu_scale, cuda_scale)) in cpu_scales_vec
            .iter()
            .zip(cuda_scales_vec.iter())
            .enumerate()
        {
            let scale_diff = (cpu_scale - cuda_scale).abs();
            assert!(
                scale_diff < 1e-6,
                "Scale difference at index {}: CPU={}, CUDA={}, diff={}",
                i,
                cpu_scale,
                cuda_scale,
                scale_diff
            );
        }

        // Also test that dequantization gives the same results
        let cpu_dequant = fp8_vector_dequantize(&cpu_quantized, &cpu_scales, DType::F32)?;
        let cuda_dequant =
            fp8_vector_dequantize(&cuda_quantized_cpu, &cuda_scales_cpu, DType::F32)?;

        let cpu_dequant_vec = cpu_dequant.to_vec1::<f32>()?;
        let cuda_dequant_vec = cuda_dequant.to_vec1::<f32>()?;

        let mut max_dequant_diff = 0f32;
        for (cpu_val, cuda_val) in cpu_dequant_vec.iter().zip(cuda_dequant_vec.iter()) {
            let diff = (cpu_val - cuda_val).abs();
            max_dequant_diff = max_dequant_diff.max(diff);
        }

        assert!(
            max_dequant_diff < 1e-5,
            "Max dequantization difference too large: {max_dequant_diff}"
        );

        Ok(())
    }
}