// mistralrs_quant/utils/ops.rs

1use candle_core::{
2    backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape,
3    Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::ops::{BitOr, Shl};
8
9#[cfg(feature = "cuda")]
10use crate::utils::ffi;
11#[cfg(feature = "cuda")]
12use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
13#[cfg(feature = "cuda")]
14use std::ffi::c_void;
15
/// Marker type implementing element-wise bitwise OR as a candle custom op.
struct BitWiseOr;
17
18impl BitWiseOr {
19    fn bitwise<T: WithDType + BitOr<Output = T>>(&self, vs1: &[T], vs2: &[T]) -> Vec<T> {
20        vs1.into_par_iter()
21            .zip_eq(vs2)
22            .map(|(v1, v2)| *v1 | *v2)
23            .collect()
24    }
25}
26
27impl CustomOp2 for BitWiseOr {
28    fn name(&self) -> &'static str {
29        "bitwise-or"
30    }
31
32    fn cpu_fwd(
33        &self,
34        s1: &CpuStorage,
35        l1: &Layout,
36        s2: &CpuStorage,
37        l2: &Layout,
38    ) -> Result<(CpuStorage, Shape)> {
39        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
40            return Err(Error::ShapeMismatchBinaryOp {
41                lhs: l1.shape().clone(),
42                rhs: l2.shape().clone(),
43                op: "bitwise-or",
44            });
45        }
46        if s1.dtype() != s2.dtype() {
47            return Err(Error::DTypeMismatchBinaryOp {
48                lhs: s1.dtype(),
49                rhs: s2.dtype(),
50                op: "bitwise-or",
51            });
52        }
53        match s1 {
54            CpuStorage::U8(vs1) => {
55                let vs1 = match l1.contiguous_offsets() {
56                    Some((start, end)) => &vs1[start..end],
57                    None => candle_core::bail!("Input tensor s1 must be contiguous"),
58                };
59                let vs2 = s2.as_slice::<u8>()?;
60                let vs2 = match l2.contiguous_offsets() {
61                    Some((start, end)) => &vs2[start..end],
62                    None => candle_core::bail!("Input tensor s2 must be contiguous"),
63                };
64                if vs1.len() != vs2.len() {
65                    candle_core::bail!("Input tensors must have the same number of elements");
66                };
67                let result = self.bitwise(vs1, vs2);
68                let result = CpuStorage::U8(result);
69                Ok((result, l1.shape().clone()))
70            }
71            CpuStorage::I16(vs1) => {
72                let vs2 = &s2.as_slice::<i16>().unwrap();
73                let result = self.bitwise(vs1, vs2);
74                let result = CpuStorage::I16(result);
75                Ok((result, l1.shape().clone()))
76            }
77            CpuStorage::U32(vs1) => {
78                let vs2 = &s2.as_slice::<u32>().unwrap();
79                let result = self.bitwise(vs1, vs2);
80                let result = CpuStorage::U32(result);
81                Ok((result, l1.shape().clone()))
82            }
83            CpuStorage::I64(vs1) => {
84                let vs2 = &s2.as_slice::<i64>().unwrap();
85                let result = self.bitwise(vs1, vs2);
86                let result = CpuStorage::I64(result);
87                Ok((result, l1.shape().clone()))
88            }
89            CpuStorage::I32(vs1) => {
90                let vs2 = &s2.as_slice::<i32>().unwrap();
91                let result = self.bitwise(vs1, vs2);
92                let result = CpuStorage::I32(result);
93                Ok((result, l1.shape().clone()))
94            }
95            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")),
96            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")),
97            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")),
98            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")),
99            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or")),
100        }
101    }
102    #[cfg(feature = "cuda")]
103    fn cuda_fwd(
104        &self,
105        s1: &CudaStorage,
106        l1: &Layout,
107        s2: &CudaStorage,
108        l2: &Layout,
109    ) -> Result<(CudaStorage, Shape)> {
110        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
111            return Err(Error::ShapeMismatchBinaryOp {
112                lhs: l1.shape().clone(),
113                rhs: l2.shape().clone(),
114                op: "bitwise-or",
115            });
116        }
117        if s1.dtype() != s2.dtype() {
118            return Err(Error::DTypeMismatchBinaryOp {
119                lhs: s1.dtype(),
120                rhs: s2.dtype(),
121                op: "bitwise-or",
122            });
123        }
124        let dev = s1.device().clone();
125        let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
126            DType::U8 => {
127                let d_in1_ptr = *s1
128                    .as_cuda_slice::<u8>()?
129                    .slice(l1.start_offset()..)
130                    .device_ptr() as *const c_void;
131                let d_in2_ptr = *s2
132                    .as_cuda_slice::<u8>()?
133                    .slice(l2.start_offset()..)
134                    .device_ptr() as *const c_void;
135                let elem_count = l1.shape().elem_count();
136                (d_in1_ptr, d_in2_ptr, elem_count)
137            }
138            DType::I16 => {
139                return Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or"));
140            }
141            DType::U32 => {
142                return Err(Error::UnsupportedDTypeForOp(DType::U32, "bitwise-or"));
143            }
144            DType::I64 => {
145                return Err(Error::UnsupportedDTypeForOp(DType::I64, "bitwise-or"));
146            }
147            DType::I32 => {
148                let d_in1_ptr = *s1
149                    .as_cuda_slice::<i32>()?
150                    .slice(l1.start_offset()..)
151                    .device_ptr() as *const c_void;
152                let d_in2_ptr = *s2
153                    .as_cuda_slice::<i32>()?
154                    .slice(l2.start_offset()..)
155                    .device_ptr() as *const c_void;
156                let elem_count = l1.shape().elem_count();
157                (d_in1_ptr, d_in2_ptr, elem_count)
158            }
159            DType::BF16 => {
160                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or"));
161            }
162            DType::F16 => {
163                return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or"));
164            }
165            DType::F32 => {
166                return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or"));
167            }
168            DType::F64 => {
169                return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or"));
170            }
171            DType::F8E4M3 => {
172                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or"));
173            }
174        };
175        let dst = match s1.dtype() {
176            DType::U8 => {
177                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
178                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
179                unsafe {
180                    ffi::mq_bitwise_or_u8(
181                        d_in1_ptr,
182                        d_in2_ptr,
183                        d_out_ptr,
184                        u32::try_from(elem_count)?,
185                    )
186                };
187                CudaStorage::wrap_cuda_slice(d_out, dev)
188            }
189            DType::I32 => {
190                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
191                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
192                unsafe {
193                    ffi::mq_bitwise_or_i32(
194                        d_in1_ptr,
195                        d_in2_ptr,
196                        d_out_ptr,
197                        u32::try_from(elem_count)?,
198                    )
199                };
200                CudaStorage::wrap_cuda_slice(d_out, dev)
201            }
202            _ => unreachable!(),
203        };
204        Ok((dst, l1.shape().clone()))
205    }
206    #[cfg(feature = "metal")]
207    fn metal_fwd(
208        &self,
209        s1: &candle_core::MetalStorage,
210        l1: &Layout,
211        s2: &candle_core::MetalStorage,
212        l2: &Layout,
213    ) -> Result<(candle_core::MetalStorage, Shape)> {
214        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
215            return Err(Error::ShapeMismatchBinaryOp {
216                lhs: l1.shape().clone(),
217                rhs: l2.shape().clone(),
218                op: "bitwise-or",
219            });
220        }
221        if s1.dtype() != s2.dtype() {
222            return Err(Error::DTypeMismatchBinaryOp {
223                lhs: s1.dtype(),
224                rhs: s2.dtype(),
225                op: "bitwise-or",
226            });
227        }
228        if !l1.is_contiguous() {
229            candle_core::bail!("Input tensor s1 must be contiguous");
230        }
231        if !l2.is_contiguous() {
232            candle_core::bail!("Input tensor s2 must be contiguous");
233        }
234
235        let command_buffer = s1.device().command_buffer()?;
236        command_buffer.set_label("bitwise-or");
237
238        let device = s1.device();
239
240        let out_shape = l1.shape().clone();
241
242        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-or")?;
243
244        crate::metal_kernels::call_bitwise_or(
245            device.device(),
246            &command_buffer,
247            &crate::metal_kernels::Kernels::new(),
248            s1.dtype(),
249            s1.buffer(),
250            s2.buffer(),
251            l1.start_offset(),
252            l2.start_offset(),
253            out_shape.elem_count(),
254            &output,
255        )
256        .map_err(candle_core::Error::wrap)?;
257
258        let newstorage = candle_core::MetalStorage::new(
259            output,
260            device.clone(),
261            out_shape.elem_count(),
262            s1.dtype(),
263        );
264        Ok((newstorage, out_shape))
265    }
266}
267
/// Extension trait providing an element-wise bitwise OR for tensors.
#[allow(dead_code)]
pub trait BitWiseOp {
    /// Element-wise bitwise OR of `self` with `rhs`; integer dtypes only.
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
}
272
impl BitWiseOp for Tensor {
    // Applies the custom op without registering a backward pass.
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWiseOr)
    }
}
/// Custom op shifting every element left by the stored number of bits.
struct Leftshift(usize);
279
280impl Leftshift {
281    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
282        let offset = T::from_f64(self.0 as f64);
283        vs.into_par_iter().map(|v| *v << offset).collect()
284    }
285}
286
287impl CustomOp1 for Leftshift {
288    fn name(&self) -> &'static str {
289        "left"
290    }
291
292    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
293        if !l1.is_contiguous() {
294            candle_core::bail!("Input tensor s1 must be contiguous");
295        }
296        match s1 {
297            CpuStorage::U8(vs1) => {
298                let result = self.leftshift(vs1);
299                let result = CpuStorage::U8(result);
300                Ok((result, l1.shape().clone()))
301            }
302            CpuStorage::I16(vs1) => {
303                let result = self.leftshift(vs1);
304                let result = CpuStorage::I16(result);
305                Ok((result, l1.shape().clone()))
306            }
307            CpuStorage::U32(vs1) => {
308                let result = self.leftshift(vs1);
309                let result = CpuStorage::U32(result);
310                Ok((result, l1.shape().clone()))
311            }
312            CpuStorage::I64(vs1) => {
313                let result = self.leftshift(vs1);
314                let result = CpuStorage::I64(result);
315                Ok((result, l1.shape().clone()))
316            }
317            CpuStorage::I32(vs1) => {
318                let result = self.leftshift(vs1);
319                let result = CpuStorage::I32(result);
320                Ok((result, l1.shape().clone()))
321            }
322            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshifr")),
323            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshifr")),
324            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshifr")),
325            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshifr")),
326            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshifr")),
327        }
328    }
329    #[cfg(feature = "cuda")]
330    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
331        if !l1.is_contiguous() {
332            candle_core::bail!("Input tensor s1 must be contiguous");
333        }
334        let dev = s1.device().clone();
335        let (d_in1_ptr, elem_count) = match s1.dtype() {
336            DType::U8 => {
337                let d_in1_ptr = *s1
338                    .as_cuda_slice::<u8>()?
339                    .slice(l1.start_offset()..)
340                    .device_ptr() as *const c_void;
341                let elem_count = l1.shape().elem_count();
342                (d_in1_ptr, elem_count)
343            }
344            DType::I16 => {
345                return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift"));
346            }
347            DType::U32 => {
348                return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
349            }
350            DType::I64 => {
351                return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
352            }
353            DType::I32 => {
354                let d_in1_ptr = *s1
355                    .as_cuda_slice::<i32>()?
356                    .slice(l1.start_offset()..)
357                    .device_ptr() as *const c_void;
358                let elem_count = l1.shape().elem_count();
359                (d_in1_ptr, elem_count)
360            }
361            DType::BF16 => {
362                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
363            }
364            DType::F16 => {
365                return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
366            }
367            DType::F32 => {
368                return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
369            }
370            DType::F64 => {
371                return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
372            }
373            DType::F8E4M3 => {
374                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
375            }
376        };
377        let dst = match s1.dtype() {
378            DType::U8 => {
379                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
380                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
381                unsafe {
382                    ffi::mq_leftshift_u8(
383                        d_in1_ptr,
384                        d_out_ptr,
385                        u32::try_from(elem_count)?,
386                        self.0 as i32,
387                    )
388                };
389                CudaStorage::wrap_cuda_slice(d_out, dev)
390            }
391            DType::I32 => {
392                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
393                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
394                unsafe {
395                    ffi::mq_leftshift_i32(
396                        d_in1_ptr,
397                        d_out_ptr,
398                        u32::try_from(elem_count)?,
399                        self.0 as i32,
400                    )
401                };
402                CudaStorage::wrap_cuda_slice(d_out, dev)
403            }
404            _ => unreachable!(),
405        };
406        Ok((dst, l1.shape().clone()))
407    }
408    #[cfg(feature = "metal")]
409    fn metal_fwd(
410        &self,
411        s1: &candle_core::MetalStorage,
412        l1: &Layout,
413    ) -> Result<(candle_core::MetalStorage, Shape)> {
414        if !l1.is_contiguous() {
415            candle_core::bail!("Input tensor s1 must be contiguous");
416        }
417
418        let command_buffer = s1.device().command_buffer()?;
419        command_buffer.set_label("bitwise-leftshift");
420
421        let device = s1.device();
422
423        let out_shape = l1.shape().clone();
424
425        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
426
427        crate::metal_kernels::call_bitwise_leftshift(
428            device.device(),
429            &command_buffer,
430            &crate::metal_kernels::Kernels::new(),
431            s1.dtype(),
432            s1.buffer(),
433            l1.start_offset(),
434            self.0 as u32,
435            out_shape.elem_count(),
436            &output,
437        )
438        .map_err(candle_core::Error::wrap)?;
439
440        let newstorage = candle_core::MetalStorage::new(
441            output,
442            device.clone(),
443            out_shape.elem_count(),
444            s1.dtype(),
445        );
446        Ok((newstorage, out_shape))
447    }
448}
449
/// Extension trait providing an element-wise left shift for tensors.
#[allow(dead_code)]
pub trait LeftshiftOp {
    /// Shift every element of `self` left by `n` bits; integer dtypes only.
    fn leftshift(&self, n: usize) -> Result<Tensor>;
}
454
impl LeftshiftOp for Tensor {
    // Applies the custom op without registering a backward pass.
    fn leftshift(&self, n: usize) -> Result<Tensor> {
        self.apply_op1_no_bwd(&Leftshift(n))
    }
}
460
// Gate the test module behind cfg(test): without it, the tests (and their
// `use` statements) were compiled into every non-test build.
#[cfg(test)]
mod tests {
    #[test]
    fn test_bitwise_or_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitwise_or_metal() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_leftshift_cpu() {
        use crate::utils::ops::LeftshiftOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_leftshift_cuda() {
        use crate::utils::ops::LeftshiftOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_leftshift_metal() {
        use crate::utils::ops::LeftshiftOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_and_leftshift_cuda() {
        use crate::utils::{ops::BitWiseOp, LeftshiftOp};
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap();
        let b = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap();
        let c = a
            .leftshift(4)
            .unwrap()
            .bitwise_or(&b)
            .unwrap()
            .to_vec1::<u8>()
            .unwrap();
        let av = a.to_vec1::<u8>().unwrap();
        let bv = b.to_vec1::<u8>().unwrap();
        assert_eq!(av, [0b00001111]);
        assert_eq!(bv, [0b00001111]);
        assert_eq!(c, [0b11111111]);
    }

    #[test]
    fn test_bitpack_8bit_cpu() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_8bit_cuda() {
        use crate::HqqBits;
        use candle_core::DType;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_dtype(DType::U8)
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_8bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[test]
    fn test_bitpack_4bit() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_4bit_cuda() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_4bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }
}