// mistralrs_quant/vector_fp8/ops.rs

use candle_core::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VECTOR_SIZE;
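// Per the public API docs below, each scale covers one vector of 128 elements, so
// `VECTOR_SIZE` (defined in the parent module) is expected to be 128.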

struct Fp8VectorDequantize {
    out_ty: DType,
}

impl Fp8VectorDequantize {
    fn dispatch_dequant_vector<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        _weight_l: &candle_core::Layout,
        scale_l: &candle_core::Layout,
    ) -> candle_core::Result<Vec<T>> {
        let num_elements = weight.len();
        let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

        if scale.len() != num_vectors {
            candle_core::bail!(
                "Scale length {} doesn't match expected number of vectors {}",
                scale.len(),
                num_vectors
            );
        }

        let res = vec![T::zero(); num_elements];

        (0..num_vectors).into_par_iter().for_each(|vector_idx| {
            let res_ptr = res.as_ptr() as *mut T;
            let vector_scale = scale[vector_idx * scale_l.stride()[0]];
            let vector_start = vector_idx * VECTOR_SIZE;
            let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

            for (idx, &weight_val) in weight[vector_start..vector_end].iter().enumerate() {
                let global_idx = vector_start + idx;
                // SAFETY: We know each thread will only update independent values!
                unsafe {
                    *res_ptr.wrapping_add(global_idx) =
                        T::from_f64((weight_val.to_f32() * vector_scale) as f64);
                }
            }
        });

        Ok(res)
    }
}

impl CustomOp2 for Fp8VectorDequantize {
    fn name(&self) -> &'static str {
        "fp8-vector-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle_core::CpuStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CpuStorage,
        weight_l: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let candle_core::CpuStorage::F8E4M3(weight) = weight_s else {
            candle_core::bail!("Expected F8E4M3 weight!");
        };
        let candle_core::CpuStorage::F32(scale) = scale_s else {
            candle_core::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_vector(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        scale_s: &candle_core::CudaStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::CudaStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::{backend::BackendStorage, CudaStorage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if !ffi::HAVE_VECTOR_DEQUANT_KERNELS {
            candle_core::bail!("Do not have vector FP8 dequant kernels.");
        }

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }

        let dev = weight_s.device();
        let num_elements = weight_l.shape().elem_count();

        let (weight, _weight_guard) =
            slice_ptr(weight_s.as_cuda_slice::<F8E4M3>()?, weight_l.start_offset());
        let (scale, _scale_guard) =
            slice_ptr(scale_s.as_cuda_slice::<f32>()?, scale_l.start_offset());

        let res = match self.out_ty {
            DType::F32 => {
                let output = dev.alloc_zeros::<f32>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f32(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::F16 => {
                let output = dev.alloc_zeros::<f16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_f16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            DType::BF16 => {
                let output = dev.alloc_zeros::<bf16>(num_elements)?;
                let (output_ptr, output_guard) = slice_ptr(&output, 0);
                unsafe {
                    ffi::launch_dequant_fp8_vector_kernel_bf16(
                        weight as *const _,
                        scale as *const _,
                        output_ptr as *mut _,
                        num_elements,
                        dev.cuda_stream().cu_stream(),
                    )
                };
                drop(output_guard);
                CudaStorage::wrap_cuda_slice(output, dev.clone())
            }
            other => candle_core::bail!("unexpected out type of fp8 vector dequant {other:?}"),
        };

        Ok((res, weight_l.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        scale_s: &candle_core::MetalStorage,
        scale_l: &candle_core::Layout,
        weight_s: &candle_core::MetalStorage,
        weight_l: &candle_core::Layout,
    ) -> Result<(candle_core::MetalStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;

        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle_core::bail!("Expected weight to have start offset 0, contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle_core::bail!("Expected scales to have start offset 0, contiguous");
        }

        let device = weight_s.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("fp8-vector-dequant");

        let num_elements = weight_l.shape().elem_count();
        let out_shape = weight_l.shape().clone();

        let output = device.new_buffer(num_elements, self.out_ty, "fp8-vector-dequant-output")?;

        crate::metal_kernels::call_fp8_vector_dequant(
            device.device(),
            &encoder,
            &crate::metal_kernels::Kernels::new(),
            self.out_ty,
            weight_s.buffer(),
            scale_s.buffer(),
            &output,
            num_elements,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage =
            candle_core::MetalStorage::new(output, device.clone(), num_elements, self.out_ty);
        Ok((newstorage, out_shape))
    }
}

/// FP8 vector dequantize.
/// - Expects weight to be fp8
/// - Expects inv_scales to be f32
/// - weight * inv_scale = dequantized
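///
/// A minimal usage sketch (marked `ignore`, so it is not run as a doc-test), mirroring the
/// CPU unit test below: a rank-1 fp8 weight whose length is a multiple of the vector size,
/// with one f32 scale per 128-element vector.
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let dev = Device::Cpu;
/// // 256 elements => 2 vectors of 128 elements, so 2 scales.
/// let weight = Tensor::ones(256, DType::F8E4M3, &dev)?;
/// let scales = Tensor::new(&[2.0f32, 3.0f32], &dev)?;
/// let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
/// // Every element in the first vector dequantizes to 1.0 * 2.0 = 2.0.
/// assert_eq!(dequant.to_vec1::<f32>()?[0], 2.0);
/// ```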
pub fn fp8_vector_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(weight, &Fp8VectorDequantize { out_ty })
}

/// CPU implementation of vector quantization
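///
/// For each 128-element vector this computes `scale = max_abs / 448.0` (448 being the largest
/// finite F8E4M3 magnitude, matching the clamp below; a tiny fallback scale is used for
/// all-zero vectors), then stores each element as `F8E4M3::from_f32(clamp(x / scale, -448.0, 448.0))`.
/// Dequantization multiplies by the same scale.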
fn cpu_vector_quantize<T: WithDType>(
    input: &[T],
    num_elements: usize,
) -> candle_core::Result<(Vec<F8E4M3>, Vec<f32>)> {
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let weight = vec![F8E4M3::from_f32(0.0); num_elements];
    let scale = vec![0f32; num_vectors];

    (0..num_vectors).into_par_iter().for_each(|vector_idx| {
        let weight_ptr = weight.as_ptr() as *mut F8E4M3;
        let scale_ptr = scale.as_ptr() as *mut f32;

        let vector_start = vector_idx * VECTOR_SIZE;
        let vector_end = vector_start + VECTOR_SIZE.min(num_elements - vector_start);

        // Find max absolute value in vector
        let mut max_abs = 0f32;
        for &input_val in &input[vector_start..vector_end] {
            let val = input_val.to_f64() as f32;
            let abs_val = val.abs();
            if abs_val > max_abs {
                max_abs = abs_val;
            }
        }

        // Calculate scale
        let vector_scale = if max_abs > 0.0 {
            max_abs / 448.0
        } else {
            1e-12
        };

        // SAFETY: We know each thread will only update independent values!
        unsafe {
            *scale_ptr.wrapping_add(vector_idx) = vector_scale;
        }

        // Quantize values
        for (idx, &input_val) in input[vector_start..vector_end].iter().enumerate() {
            let global_idx = vector_start + idx;
            let val = input_val.to_f64() as f32;
            let scaled_val = (val / vector_scale).clamp(-448.0, 448.0);

            // SAFETY: We know each thread will only update independent values!
            unsafe {
                *weight_ptr.wrapping_add(global_idx) = F8E4M3::from_f32(scaled_val);
            }
        }
    });

    Ok((weight, scale))
}

/// FP8 vector quantize for CPU
fn cpu_fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    let num_elements = input.shape().elem_count();
    let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

    let (weight_data, scale_data) = match input.dtype() {
        DType::F32 => {
            let data = input.to_vec1::<f32>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::F16 => {
            let data = input.to_vec1::<half::f16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        DType::BF16 => {
            let data = input.to_vec1::<half::bf16>()?;
            cpu_vector_quantize(&data, num_elements)?
        }
        other => candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}"),
    };

    // Create tensors from the raw data
    let weight = Tensor::from_vec(weight_data, input.shape(), input.device())?;
    let scale = Tensor::from_vec(scale_data, num_vectors, input.device())?;

    Ok((weight, scale))
}

/// FP8 vector quantize.
/// - Expects input to be f32, f16, or bf16
/// - Returns a tuple of (quantized_weight, scales)
/// - quantized_weight is fp8
/// - scales is f32
/// - Each scale corresponds to a vector of 128 elements
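///
/// A minimal round-trip sketch (marked `ignore`, so it is not run as a doc-test), following
/// the unit tests below; the element count must be a multiple of 128.
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let dev = Device::Cpu;
/// let input = Tensor::randn(0f32, 2f32, 256, &dev)?;
/// let (quantized, scales) = fp8_vector_quantize(&input)?;
/// assert_eq!(scales.dims1()?, 2); // 256 / 128 = 2 vectors
/// // Round-trip back to the input dtype; values match up to FP8 precision.
/// let roundtrip = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;
/// assert_eq!(roundtrip.shape(), input.shape());
/// ```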
pub fn fp8_vector_quantize(input: &Tensor) -> Result<(Tensor, Tensor)> {
    // Check that tensor size is divisible by 128
    let num_elements = input.shape().elem_count();
    if num_elements % VECTOR_SIZE != 0 {
        candle_core::bail!(
            "Tensor size {} must be divisible by {} for vector FP8 quantization",
            num_elements,
            VECTOR_SIZE
        );
    }

    // Check if we should use CPU implementation
    if matches!(input.device(), candle_core::Device::Cpu) {
        return cpu_fp8_vector_quantize(input);
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::{CudaStorage, Device, Storage};
        use half::{bf16, f16};

        use crate::{utils::slice_ptr, vector_fp8::ffi};

        if matches!(input.device(), Device::Cuda(_)) {
            if !ffi::HAVE_VECTOR_QUANT_KERNELS {
                candle_core::bail!("Do not have vector FP8 quant kernels.");
            }

            let input_l = input.layout();
            if input_l.start_offset() != 0 || !input_l.is_contiguous() {
                candle_core::bail!("Expected input to have start offset 0, contiguous");
            }

            let dev = match input.device() {
                Device::Cuda(dev) => dev,
                _ => unreachable!(),
            };

            let num_vectors = num_elements.div_ceil(VECTOR_SIZE);

            // Allocate output buffers
            let weight_output = dev.alloc_zeros::<F8E4M3>(num_elements)?;
            let scale_output = dev.alloc_zeros::<f32>(num_vectors)?;

            let (weight_ptr, _weight_guard) = slice_ptr(&weight_output, 0);
            let (scale_ptr, _scale_guard) = slice_ptr(&scale_output, 0);

            match input.dtype() {
                DType::F32 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f32>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f32(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::F16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<f16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_f16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                DType::BF16 => {
                    let input_storage = input.storage_and_layout().0;
                    let input_s = match &*input_storage {
                        Storage::Cuda(cuda_storage) => cuda_storage.as_cuda_slice::<bf16>()?,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };
                    let (input_ptr, _input_guard) = slice_ptr(input_s, input_l.start_offset());
                    unsafe {
                        ffi::launch_quant_fp8_vector_kernel_bf16(
                            input_ptr as *const _,
                            weight_ptr as *mut _,
                            scale_ptr as *mut _,
                            num_elements,
                            dev.cuda_stream().cu_stream(),
                        )
                    };
                }
                other => {
                    candle_core::bail!("unexpected input type for fp8 vector quant: {other:?}")
                }
            }

            // Drop guards before moving the buffers
            drop(_weight_guard);
            drop(_scale_guard);

            // Create weight tensor by wrapping the CUDA storage
            let weight_storage = CudaStorage::wrap_cuda_slice(weight_output, dev.clone());
            let weight = Tensor::from((Storage::Cuda(weight_storage), input.shape().clone()));

            // Create scale tensor
            let scale_storage = CudaStorage::wrap_cuda_slice(scale_output, dev.clone());
            let scale = Tensor::from((
                Storage::Cuda(scale_storage),
                candle_core::Shape::from_dims(&[num_vectors]),
            ));

            Ok((weight, scale))
        } else {
            candle_core::bail!("Expected CUDA device.");
        }
    }

    #[cfg(not(feature = "cuda"))]
    {
        candle_core::bail!("FP8 vector quantization on non-CPU devices requires CUDA feature");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn test_fp8_vector_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let num_elements = 256; // 2 vectors of 128 elements
        let weight = Tensor::ones(num_elements, DType::F8E4M3, dev)?;
        let scales = Tensor::new(&[2.0f32, 3.0f32], dev)?; // 2 scales for 2 vectors

        let dequant = fp8_vector_dequantize(&weight, &scales, DType::F32)?;
        let res = dequant.to_vec1::<f32>()?;

        // First 128 elements should be 2.0, next 128 should be 3.0
        for &val in &res[0..128] {
            assert_eq!(val, 2.0);
        }
        for &val in &res[128..256] {
            assert_eq!(val, 3.0);
        }

        Ok(())
    }

    #[test]
    fn test_fp8_vector_quant_cpu() -> Result<()> {
        let dev = &Device::Cpu;

        // Create test input with 256 elements (2 vectors)
        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        // Quantize
        let (quantized, scales) = fp8_vector_quantize(&input)?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2); // 256/128 = 2 vectors

        // Dequantize
        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        // FP8 E4M3 has limited precision, so we expect some error
        assert!(max_error < 0.27, "Max error {max_error} is too large");

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_quant_dequant_roundtrip() -> Result<()> {
        let dev = &Device::new_cuda(0)?;

        // Create test input with 256 elements (2 vectors)
        let input = Tensor::randn(0f32, 2f32, 256, dev)?;

        // Quantize
        let (quantized, scales) = fp8_vector_quantize(&input)?;

        // Verify shapes
        assert_eq!(quantized.shape(), input.shape());
        assert_eq!(scales.dims1()?, 2); // 256/128 = 2 vectors

        // Dequantize
        let dequantized = fp8_vector_dequantize(&quantized, &scales, input.dtype())?;

        // Check that shapes match
        assert_eq!(dequantized.shape(), input.shape());

        // The values won't be exactly the same due to quantization loss,
        // but they should be reasonably close
        let input_vec = input.to_vec1::<f32>()?;
        let dequant_vec = dequantized.to_vec1::<f32>()?;

        let mut max_error = 0f32;
        for (val_in, val_out) in input_vec.iter().zip(dequant_vec.iter()) {
            let error = (val_in - val_out).abs();
            max_error = max_error.max(error);
        }

        // FP8 E4M3 has limited precision, so we expect some error
        assert!(max_error < 0.24, "Max error {} is too large", max_error);

        Ok(())
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_fp8_vector_cpu_cuda_equivalence() -> Result<()> {
        let cpu_dev = &Device::Cpu;
        let cuda_dev = &Device::new_cuda(0)?;

        // Create the same input data on both devices
        let input_data: Vec<f32> = (0..256).map(|i| ((i as f32) - 128.0) / 10.0).collect();
        let cpu_input = Tensor::from_vec(input_data.clone(), 256, cpu_dev)?;
        let cuda_input = Tensor::from_vec(input_data, 256, cuda_dev)?;

        // Quantize on CPU
        let (cpu_quantized, cpu_scales) = fp8_vector_quantize(&cpu_input)?;

        // Quantize on CUDA
        let (cuda_quantized, cuda_scales) = fp8_vector_quantize(&cuda_input)?;

        // Move CUDA results to CPU for comparison
        let cuda_quantized_cpu = cuda_quantized.to_device(cpu_dev)?;
        let cuda_scales_cpu = cuda_scales.to_device(cpu_dev)?;

        // Compare quantized weights
        let cpu_quant_vec = cpu_quantized.to_vec1::<F8E4M3>()?;
        let cuda_quant_vec = cuda_quantized_cpu.to_vec1::<F8E4M3>()?;

        assert_eq!(cpu_quant_vec.len(), cuda_quant_vec.len());

        let mut num_differences = 0;
        for (i, (cpu_val, cuda_val)) in cpu_quant_vec.iter().zip(cuda_quant_vec.iter()).enumerate()
        {
            if cpu_val.to_f32() != cuda_val.to_f32() {
                // Allow small differences due to floating point precision
                let diff = (cpu_val.to_f32() - cuda_val.to_f32()).abs();
                if diff > 1e-6 {
                    num_differences += 1;
                    if num_differences < 10 {
                        println!(
                            "Difference at index {}: CPU={}, CUDA={}, diff={}",
                            i,
                            cpu_val.to_f32(),
                            cuda_val.to_f32(),
                            diff
                        );
                    }
                }
            }
        }

        // FP8 quantization should be deterministic, so we expect very few differences
        assert!(
            num_differences < 5,
            "Too many differences between CPU and CUDA quantization: {}",
            num_differences
        );

        // Compare scales
        let cpu_scales_vec = cpu_scales.to_vec1::<f32>()?;
        let cuda_scales_vec = cuda_scales_cpu.to_vec1::<f32>()?;

        assert_eq!(cpu_scales_vec.len(), cuda_scales_vec.len());

        for (i, (cpu_scale, cuda_scale)) in cpu_scales_vec
            .iter()
            .zip(cuda_scales_vec.iter())
            .enumerate()
        {
            let scale_diff = (cpu_scale - cuda_scale).abs();
            assert!(
                scale_diff < 1e-6,
                "Scale difference at index {}: CPU={}, CUDA={}, diff={}",
                i,
                cpu_scale,
                cuda_scale,
                scale_diff
            );
        }

        // Also test that dequantization gives the same results
        let cpu_dequant = fp8_vector_dequantize(&cpu_quantized, &cpu_scales, DType::F32)?;
        let cuda_dequant =
            fp8_vector_dequantize(&cuda_quantized_cpu, &cuda_scales_cpu, DType::F32)?;

        let cpu_dequant_vec = cpu_dequant.to_vec1::<f32>()?;
        let cuda_dequant_vec = cuda_dequant.to_vec1::<f32>()?;

        let mut max_dequant_diff = 0f32;
        for (cpu_val, cuda_val) in cpu_dequant_vec.iter().zip(cuda_dequant_vec.iter()) {
            let diff = (cpu_val - cuda_val).abs();
            max_dequant_diff = max_dequant_diff.max(diff);
        }

        assert!(
            max_dequant_diff < 1e-5,
            "Max dequantization difference too large: {}",
            max_dequant_diff
        );

        Ok(())
    }
}