mistralrs_quant/vector_fp8/ops.rs

#[cfg(feature = "cuda")]
use candle_core::from_storage_no_op;
use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VECTOR_SIZE;

struct Fp8VectorDequantize {
    out_ty: DType,
}

impl Fp8VectorDequantize {
    fn dispatch_dequant_vector<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        _weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let num_elements = weight.len();
        let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

        if scale.len() != num_vectors {
            candle_core::bail!(
                "Scale length {} doesn't match expected number of vectors {}",
                scale.len(),
                num_vectors
            );
        }

        let res = vec![T::zero(); num_elements];

        (0..num_vectors).into_par_iter().for_each(|vector_idx| {
            let res_ptr = res.as_ptr() as *mut T;
            let vector_scale = scale[vector_idx * scale_l.stride()[0]];
            let vector_start = vector_idx * VECTOR_SIZE;
            let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

            for (idx, &weight_val) in weight[vector_start..vector_end].iter().enumerate() {
                let global_idx = vector_start + idx;
                // SAFETY: We know each thread will only update independent values!
                unsafe {
                    *res_ptr.wrapping_add(global_idx) =
                        T::from_f64((weight_val.to_f32() * vector_scale) as f64);
                }
            }
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8VectorDequantize {
    fn name(&self) -> &'static str {
        "fp8-vector-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if !ffi::HAVE_VECTOR_DEQUANT_KERNELS {
            candle_core::bail!("Do not have vector FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0 and be contiguous");
        }

        let dev = weight_s.device();
        let num_elements = weight_l.shape().elem_count();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

        let res = match self.out_ty {
            DType::F32 => {
                let output = dev.alloc_zeros::<f32>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::F16 => {
                let output = dev.alloc_zeros::<f16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::BF16 => {
                let output = dev.alloc_zeros::<bf16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        _scale_s: &candle_core::MetalStorage,
        _scale_l: &candle_core::Layout,
        _weight_s: &candle_core::MetalStorage,
        _weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        candle_core::bail!("FP8 vector dequantization not yet implemented for Metal");
    }
}

/// FP8 vector dequantize.
/// - Expects weight to be fp8
/// - Expects inv_scales to be f32
/// - weight * inv_scale = dequantized
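///
/// A minimal usage sketch; the import path below is an assumption and may differ from
/// wherever this function is actually re-exported in the crate:
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use mistralrs_quant::vector_fp8::ops::fp8_vector_dequantize;
///
/// // Two 128-element vectors of FP8 ones, with one f32 inverse scale per vector.
/// let dev = Device::Cpu;
/// let weight = Tensor::ones(256, DType::F8E4M3, &dev)?;
/// let inv_scales = Tensor::new(&[2.0f32, 3.0f32], &dev)?;
///
/// // Each element is multiplied by its vector's scale: the first 128 become 2.0, the rest 3.0.
/// let dequant = fp8_vector_dequantize(&weight, &inv_scales, DType::F32)?;
/// assert_eq!(dequant.to_vec1::<f32>()?[0], 2.0);
/// # Ok::<(), candle_core::Error>(())
/// ```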
pub fn fp8_vector_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(weight, &Fp8VectorDequantize { out_ty })
}

/// CPU implementation of vector quantization
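///
/// For each `VECTOR_SIZE` (128-element) vector, the scale is `max_abs / 448.0` (448 being the
/// largest finite F8E4M3 magnitude) and every value is stored as `clamp(x / scale, -448, 448)`
/// rounded to FP8. For example, a vector whose largest magnitude is 4.48 gets scale 0.01, so
/// 4.48 itself is stored as 448 and dequantizes back to `448 * 0.01 = 4.48`.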
fn cpu_vector_quantize<T: WithDType>(
    input: &[T],
    num_elements: usize,
) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let weight = vec![F8E4M3::from_f32(0.0); num_elements];
    let scale = vec![0f32; num_vectors];

    (0..num_vectors).into_par_iter().for_each(|vector_idx| {
        let weight_ptr = weight.as_ptr() as *mut F8E4M3;
        let scale_ptr = scale.as_ptr() as *mut f32;

        let vector_start = vector_idx * VECTOR_SIZE;
        let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

        // Find max absolute value in vector
        let mut max_abs = 0f32;
        for &input_val in &input[vector_start..vector_end] {
            let val = input_val.to_f64() as f32;
            let abs_val = val.abs();
            if abs_val > max_abs {
                max_abs = abs_val;
            }
        }

        // Calculate scale
        let vector_scale = if max_abs > 0.0 {
            max_abs / 448.0
        } else {
            1e-12
        };

        // SAFETY: We know each thread will only update independent values!
        unsafe {
            *scale_ptr.wrapping_add(vector_idx) = vector_scale;
        }

        // Quantize values
        for (idx, &input_val) in input[vector_start..vector_end].iter().enumerate() {
            let global_idx = vector_start + idx;
            let val = input_val.to_f64() as f32;
            let scaled_val = (val / vector_scale).clamp(-448.0, 448.0);

            // SAFETY: We know each thread will only update independent values!
            unsafe {
                *weight_ptr.wrapping_add(global_idx) = F8E4M3::from_f32(scaled_val);
            }
        }
    });

    Ok((weight, scale))
}

/// FP8 vector quantize for CPU
fn cpu_fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let (weight_data, scale_data) = match input.dtype() {
        DType::F32 => {
            let data = input.to_vec1::<f32>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::F16 => {
            let data = input.to_vec1::<half::f16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::BF16 => {
            let data = input.to_vec1::<half::bf16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        other => candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}"),
    };

    // Create tensors from the raw data
    let weight = Tensor::from_vec(weight_data, input.shape(), input.device())?;
    let scale = Tensor::from_vec(scale_data, num_vectors, input.device())?;

    Ok((weight, scale))
}

/// FP8 vector quantize.
/// - Expects input to be f32, f16, or bf16
/// - Returns a tuple of (quantized_weight, scales)
/// - quantized_weight is fp8
/// - scales is f32
/// - Each scale corresponds to a vector of 128 elements
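///
/// A minimal round-trip sketch; the import path below is an assumption and may differ from
/// wherever these functions are actually re-exported in the crate:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// use mistralrs_quant::vector_fp8::ops::{fp8_vector_dequantize, fp8_vector_quantize};
///
/// // 256 elements = 2 vectors of 128, so the tensor size is divisible by VECTOR_SIZE.
/// let dev = Device::Cpu;
/// let input = Tensor::randn(0f32, 2f32, 256, &dev)?;
///
/// let (quantized, scales) = fp8_vector_quantize(&input)?;
/// assert_eq!(scales.dims1()?, 2); // one f32 scale per 128-element vector
///
/// // Dequantizing recovers the input up to FP8 E4M3 precision.
/// let roundtrip = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;
/// assert_eq!(roundtrip.shape(), input.shape());
/// # Ok::<(), candle_core::Error>(())
/// ```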
pub fn fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    // Check that tensor size is divisible by 128
    let num_elements = input.shape().elem_count();
    if num_elements % VECTOR_SIZE != 0 {
        candle_core::bail!(
            "Tensor size {} must be divisible by {} for vector FP8 quantization",
            num_elements,
            VECTOR_SIZE
        );
    }

    // Check if we should use CPU implementation
    if matches!(input.device(), candle_core::Device::Cpu) {
        return cpu_fp8_vector_quantize(input);
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if matches!(input.device(), Device::Cuda(_)) {
            if !ffi::HAVE_VECTOR_QUANT_KERNELS {
                candle_core::bail!("Do not have vector FP8 quant kernels.");
            }

            let input_l = input.layout();
            if input_l.start_offset() != 0 || !input_l.is_contiguous() {
                candle_core::bail!("Expected input to have start offset 0 and be contiguous");
            }

            let dev = match input.device() {
                Device::Cuda(dev) => dev,
                _ => unreachable!(),
            };

            let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

            // Allocate output buffers
            let weight_output = dev.alloc_zeros::<F8E4M3>(num_elements)?;
            let scale_output = dev.alloc_zeros::<f32>(num_vectors)?;

            let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
            let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

            match input.dtype() {
                DType::F32 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f32(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::F16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::BF16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(&input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_bf16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                other => {
                    candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}")
                }
            }

            // Drop guards before moving the buffers
            drop(_weight_guard);
            drop(_scale_guard);

            // Create weight tensor by wrapping the CUDA storage
            let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
            let weight =
                from_storage_no_op(Storage::Cuda(weight_storage), input.shape().clone(), false);

            // Create scale tensor
            let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
            let scale = from_storage_no_op(
                Storage::Cuda(scale_storage),
                candle_core::Shape::from_dims(&[num_vectors]),
                false,
            );

            return Ok((weight, scale));
        } else {
            candle_core::bail!("Expected CUDA device.");
        }
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 vector quantization on non-CPU devices requires CUDA feature");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn test_fp8_vector_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let num_elements = 256; // 2 vectors of 128 elements
        let weight = Tensor::ones(num_elements, DType::F8E4M3, dev)?;
        let scales = Tensor::new(&[2.0f32, 3.0f32], dev)?; // 2 scales for 2 vectors

        let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
        let res = dequant.to_vec1::<f32>()?;

        // First 128 elements should be 2.0, next 128 should be 3.0
        for &val in &res[0..128] {
            assert_eq!(val, 2.0);
        }
        for &val in &res[128..256] {
            assert_eq!(val, 3.0);
        }

        Ok(())
    }

    #[test]
    fn test_fp8_vector_quant_cpu() -> Result<()> {
        let dev = &Device::Cpu;

        // Create test input with 256 elements (2 vectors)
        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        // Quantize
        let (quantized, scales) = fp8_vector_quantize(&input)?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2); // 256/128 = 2 vectors

        // Dequantize
        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        // FP8 E4M3 has limited precision, so we expect some error
        assert!(max_error < 0.25, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        // Create test input with 256 elements (2 vectors)
        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        // Quantize
        let (quantized, scales) = fp8_vector_quantize(&input)?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2); // 256/128 = 2 vectors

        // Dequantize
        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        // FP8 E4M3 has limited precision, so we expect some error
        assert!(max_error < 0.24, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_cpu_cuda_equivalence() -> Result<()> {
        let cpu_dev = &Device::Cpu;
        let cuda_dev = &Device::new_cuda(0)?;

        // Create the same input data on both devices
        let input_data: Vec<f32> = (0..256).map(|i| ((i as f32) - 128.0) / 10.0).collect();
        let cpu_input = Tensor::from_vec(input_data.clone(), 256, cpu_dev)?;
        let cuda_input = Tensor::from_vec(input_data, 256, cuda_dev)?;

        // Quantize on CPU
        let (cpu_quantized, cpu_scales) = fp8_vector_quantize(&cpu_input)?;

        // Quantize on CUDA
        let (cuda_quantized, cuda_scales) = fp8_vector_quantize(&cuda_input)?;

        // Move CUDA results to CPU for comparison
        let cuda_quantized_cpu = cuda_quantized.to_device(cpu_dev)?;
        let cuda_scales_cpu = cuda_scales.to_device(cpu_dev)?;

        // Compare quantized weights
        let cpu_quant_vec = cpu_quantized.to_vec1::<F8E4M3>()?;
        let cuda_quant_vec = cuda_quantized_cpu.to_vec1::<F8E4M3>()?;

        assert_eq!(cpu_quant_vec.len(), cuda_quant_vec.len());

        let mut num_differences = 0;
        for (i, (cpu_val, cuda_val)) in cpu_quant_vec.iter().zip(cuda_quant_vec.iter()).enumerate()
        {
            if cpu_val.to_f32() != cuda_val.to_f32() {
                // Allow small differences due to floating point precision
                let diff = (cpu_val.to_f32() - cuda_val.to_f32()).abs();
                if diff > 1e-6 {
                    num_differences += 1;
                    if num_differences < 10 {
                        println!(
                            "Difference at index {}: CPU={}, CUDA={}, diff={}",
                            i,
                            cpu_val.to_f32(),
                            cuda_val.to_f32(),
                            diff
                        );
                    }
                }
            }
        }

        // FP8 quantization should be deterministic, so we expect very few differences
        assert!(
            num_differences < 5,
            "Too many differences between CPU and CUDA quantization: {}",
            num_differences
        );

        // Compare scales
        let cpu_scales_vec = cpu_scales.to_vec1::<f32>()?;
        let cuda_scales_vec = cuda_scales_cpu.to_vec1::<f32>()?;

        assert_eq!(cpu_scales_vec.len(), cuda_scales_vec.len());

        for (i, (cpu_scale, cuda_scale)) in cpu_scales_vec
            .iter()
            .zip(cuda_scales_vec.iter())
            .enumerate()
        {
            let scale_diff = (cpu_scale - cuda_scale).abs();
            assert!(
                scale_diff < 1e-6,
                "Scale difference at index {}: CPU={}, CUDA={}, diff={}",
                i,
                cpu_scale,
                cuda_scale,
                scale_diff
            );
        }

        // Also test that dequantization gives the same results
        let cpu_dequant = fp8_vector_dequantize(&cpu_quantized, &cpu_scales, DType::F32)?;
        let cuda_dequant =
            fp8_vector_dequantize(&cuda_quantized_cpu, &cuda_scales_cpu, DType::F32)?;

        let cpu_dequant_vec = cpu_dequant.to_vec1::<f32>()?;
        let cuda_dequant_vec = cuda_dequant.to_vec1::<f32>()?;

        let mut max_dequant_diff = 0f32;
        for (cpu_val, cuda_val) in cpu_dequant_vec.iter().zip(cuda_dequant_vec.iter()) {
            let diff = (cpu_val - cuda_val).abs();
            max_dequant_diff = max_dequant_diff.max(diff);
        }

        assert!(
            max_dequant_diff < 1e-5,
            "Max dequantization difference too large: {}",
            max_dequant_diff
        );

        Ok(())
    }
}