1use candle_core::{
2 backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape,
3 Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::ops::{BitOr, Shl};
8
9#[cfg(feature = "cuda")]
10use crate::utils::ffi;
11#[cfg(feature = "cuda")]
12use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
13#[cfg(feature = "cuda")]
14use std::ffi::c_void;
15
/// Marker type implementing element-wise bitwise OR as a candle custom op.
struct BitWiseOr;
17
18impl BitWiseOr {
19 fn bitwise<T: WithDType + BitOr<Output = T>>(&self, vs1: &[T], vs2: &[T]) -> Vec<T> {
20 vs1.into_par_iter()
21 .zip_eq(vs2)
22 .map(|(v1, v2)| *v1 | *v2)
23 .collect()
24 }
25}
26
27impl CustomOp2 for BitWiseOr {
28 fn name(&self) -> &'static str {
29 "bitwise-or"
30 }
31
32 fn cpu_fwd(
33 &self,
34 s1: &CpuStorage,
35 l1: &Layout,
36 s2: &CpuStorage,
37 l2: &Layout,
38 ) -> Result<(CpuStorage, Shape)> {
39 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
40 return Err(Error::ShapeMismatchBinaryOp {
41 lhs: l1.shape().clone(),
42 rhs: l2.shape().clone(),
43 op: "bitwise-or",
44 });
45 }
46 if s1.dtype() != s2.dtype() {
47 return Err(Error::DTypeMismatchBinaryOp {
48 lhs: s1.dtype(),
49 rhs: s2.dtype(),
50 op: "bitwise-or",
51 });
52 }
53 match s1 {
54 CpuStorage::U8(vs1) => {
55 let vs1 = match l1.contiguous_offsets() {
56 Some((start, end)) => &vs1[start..end],
57 None => candle_core::bail!("Input tensor s1 must be contiguous"),
58 };
59 let vs2 = s2.as_slice::<u8>()?;
60 let vs2 = match l2.contiguous_offsets() {
61 Some((start, end)) => &vs2[start..end],
62 None => candle_core::bail!("Input tensor s2 must be contiguous"),
63 };
64 if vs1.len() != vs2.len() {
65 candle_core::bail!("Input tensors must have the same number of elements");
66 };
67 let result = self.bitwise(vs1, vs2);
68 let result = CpuStorage::U8(result);
69 Ok((result, l1.shape().clone()))
70 }
71 CpuStorage::I16(vs1) => {
72 let vs2 = &s2.as_slice::<i16>().unwrap();
73 let result = self.bitwise(vs1, vs2);
74 let result = CpuStorage::I16(result);
75 Ok((result, l1.shape().clone()))
76 }
77 CpuStorage::U32(vs1) => {
78 let vs2 = &s2.as_slice::<u32>().unwrap();
79 let result = self.bitwise(vs1, vs2);
80 let result = CpuStorage::U32(result);
81 Ok((result, l1.shape().clone()))
82 }
83 CpuStorage::I64(vs1) => {
84 let vs2 = &s2.as_slice::<i64>().unwrap();
85 let result = self.bitwise(vs1, vs2);
86 let result = CpuStorage::I64(result);
87 Ok((result, l1.shape().clone()))
88 }
89 CpuStorage::I32(vs1) => {
90 let vs2 = &s2.as_slice::<i32>().unwrap();
91 let result = self.bitwise(vs1, vs2);
92 let result = CpuStorage::I32(result);
93 Ok((result, l1.shape().clone()))
94 }
95 CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")),
96 CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")),
97 CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")),
98 CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")),
99 CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or")),
100 }
101 }
102 #[cfg(feature = "cuda")]
103 fn cuda_fwd(
104 &self,
105 s1: &CudaStorage,
106 l1: &Layout,
107 s2: &CudaStorage,
108 l2: &Layout,
109 ) -> Result<(CudaStorage, Shape)> {
110 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
111 return Err(Error::ShapeMismatchBinaryOp {
112 lhs: l1.shape().clone(),
113 rhs: l2.shape().clone(),
114 op: "bitwise-or",
115 });
116 }
117 if s1.dtype() != s2.dtype() {
118 return Err(Error::DTypeMismatchBinaryOp {
119 lhs: s1.dtype(),
120 rhs: s2.dtype(),
121 op: "bitwise-or",
122 });
123 }
124 let dev = s1.device().clone();
125 let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
126 DType::U8 => {
127 let d_in1_ptr = *s1
128 .as_cuda_slice::<u8>()?
129 .slice(l1.start_offset()..)
130 .device_ptr() as *const c_void;
131 let d_in2_ptr = *s2
132 .as_cuda_slice::<u8>()?
133 .slice(l2.start_offset()..)
134 .device_ptr() as *const c_void;
135 let elem_count = l1.shape().elem_count();
136 (d_in1_ptr, d_in2_ptr, elem_count)
137 }
138 DType::I16 => {
139 return Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or"));
140 }
141 DType::U32 => {
142 return Err(Error::UnsupportedDTypeForOp(DType::U32, "bitwise-or"));
143 }
144 DType::I64 => {
145 return Err(Error::UnsupportedDTypeForOp(DType::I64, "bitwise-or"));
146 }
147 DType::I32 => {
148 let d_in1_ptr = *s1
149 .as_cuda_slice::<i32>()?
150 .slice(l1.start_offset()..)
151 .device_ptr() as *const c_void;
152 let d_in2_ptr = *s2
153 .as_cuda_slice::<i32>()?
154 .slice(l2.start_offset()..)
155 .device_ptr() as *const c_void;
156 let elem_count = l1.shape().elem_count();
157 (d_in1_ptr, d_in2_ptr, elem_count)
158 }
159 DType::BF16 => {
160 return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or"));
161 }
162 DType::F16 => {
163 return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or"));
164 }
165 DType::F32 => {
166 return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or"));
167 }
168 DType::F64 => {
169 return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or"));
170 }
171 DType::F8E4M3 => {
172 return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or"));
173 }
174 };
175 let dst = match s1.dtype() {
176 DType::U8 => {
177 let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
178 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
179 unsafe {
180 ffi::mq_bitwise_or_u8(
181 d_in1_ptr,
182 d_in2_ptr,
183 d_out_ptr,
184 u32::try_from(elem_count)?,
185 )
186 };
187 CudaStorage::wrap_cuda_slice(d_out, dev)
188 }
189 DType::I32 => {
190 let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
191 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
192 unsafe {
193 ffi::mq_bitwise_or_i32(
194 d_in1_ptr,
195 d_in2_ptr,
196 d_out_ptr,
197 u32::try_from(elem_count)?,
198 )
199 };
200 CudaStorage::wrap_cuda_slice(d_out, dev)
201 }
202 _ => unreachable!(),
203 };
204 Ok((dst, l1.shape().clone()))
205 }
206 #[cfg(feature = "metal")]
207 fn metal_fwd(
208 &self,
209 s1: &candle_core::MetalStorage,
210 l1: &Layout,
211 s2: &candle_core::MetalStorage,
212 l2: &Layout,
213 ) -> Result<(candle_core::MetalStorage, Shape)> {
214 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
215 return Err(Error::ShapeMismatchBinaryOp {
216 lhs: l1.shape().clone(),
217 rhs: l2.shape().clone(),
218 op: "bitwise-or",
219 });
220 }
221 if s1.dtype() != s2.dtype() {
222 return Err(Error::DTypeMismatchBinaryOp {
223 lhs: s1.dtype(),
224 rhs: s2.dtype(),
225 op: "bitwise-or",
226 });
227 }
228 if !l1.is_contiguous() {
229 candle_core::bail!("Input tensor s1 must be contiguous");
230 }
231 if !l2.is_contiguous() {
232 candle_core::bail!("Input tensor s2 must be contiguous");
233 }
234
235 let command_buffer = s1.device().command_buffer()?;
236 command_buffer.set_label("bitwise-or");
237
238 let device = s1.device();
239
240 let out_shape = l1.shape().clone();
241
242 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-or")?;
243
244 crate::metal_kernels::call_bitwise_or(
245 device.device(),
246 &command_buffer,
247 &crate::metal_kernels::Kernels::new(),
248 s1.dtype(),
249 s1.buffer(),
250 s2.buffer(),
251 l1.start_offset(),
252 l2.start_offset(),
253 out_shape.elem_count(),
254 &output,
255 )
256 .map_err(candle_core::Error::wrap)?;
257
258 let newstorage = candle_core::MetalStorage::new(
259 output,
260 device.clone(),
261 out_shape.elem_count(),
262 s1.dtype(),
263 );
264 Ok((newstorage, out_shape))
265 }
266}
267
/// Extension trait adding an element-wise bitwise OR to [`Tensor`].
#[allow(dead_code)]
pub trait BitWiseOp {
    /// Returns `self | rhs` element-wise. Both tensors must share the same
    /// shape, stride and (integer) dtype; see the `CustomOp2` impl for the
    /// exact dtype support per backend.
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
}
272
impl BitWiseOp for Tensor {
    // Thin delegation to the custom op; no backward pass is registered.
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWiseOr)
    }
}
/// Custom op shifting every element left by the wrapped number of bits.
struct Leftshift(usize);
279
280impl Leftshift {
281 fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
282 let offset = T::from_f64(self.0 as f64);
283 vs.into_par_iter().map(|v| *v << offset).collect()
284 }
285}
286
287impl CustomOp1 for Leftshift {
288 fn name(&self) -> &'static str {
289 "left"
290 }
291
292 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
293 if !l1.is_contiguous() {
294 candle_core::bail!("Input tensor s1 must be contiguous");
295 }
296 match s1 {
297 CpuStorage::U8(vs1) => {
298 let result = self.leftshift(vs1);
299 let result = CpuStorage::U8(result);
300 Ok((result, l1.shape().clone()))
301 }
302 CpuStorage::I16(vs1) => {
303 let result = self.leftshift(vs1);
304 let result = CpuStorage::I16(result);
305 Ok((result, l1.shape().clone()))
306 }
307 CpuStorage::U32(vs1) => {
308 let result = self.leftshift(vs1);
309 let result = CpuStorage::U32(result);
310 Ok((result, l1.shape().clone()))
311 }
312 CpuStorage::I64(vs1) => {
313 let result = self.leftshift(vs1);
314 let result = CpuStorage::I64(result);
315 Ok((result, l1.shape().clone()))
316 }
317 CpuStorage::I32(vs1) => {
318 let result = self.leftshift(vs1);
319 let result = CpuStorage::I32(result);
320 Ok((result, l1.shape().clone()))
321 }
322 CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshifr")),
323 CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshifr")),
324 CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshifr")),
325 CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshifr")),
326 CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshifr")),
327 }
328 }
329 #[cfg(feature = "cuda")]
330 fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
331 if !l1.is_contiguous() {
332 candle_core::bail!("Input tensor s1 must be contiguous");
333 }
334 let dev = s1.device().clone();
335 let (d_in1_ptr, elem_count) = match s1.dtype() {
336 DType::U8 => {
337 let d_in1_ptr = *s1
338 .as_cuda_slice::<u8>()?
339 .slice(l1.start_offset()..)
340 .device_ptr() as *const c_void;
341 let elem_count = l1.shape().elem_count();
342 (d_in1_ptr, elem_count)
343 }
344 DType::I16 => {
345 return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift"));
346 }
347 DType::U32 => {
348 return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
349 }
350 DType::I64 => {
351 return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
352 }
353 DType::I32 => {
354 let d_in1_ptr = *s1
355 .as_cuda_slice::<i32>()?
356 .slice(l1.start_offset()..)
357 .device_ptr() as *const c_void;
358 let elem_count = l1.shape().elem_count();
359 (d_in1_ptr, elem_count)
360 }
361 DType::BF16 => {
362 return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
363 }
364 DType::F16 => {
365 return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
366 }
367 DType::F32 => {
368 return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
369 }
370 DType::F64 => {
371 return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
372 }
373 DType::F8E4M3 => {
374 return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
375 }
376 };
377 let dst = match s1.dtype() {
378 DType::U8 => {
379 let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
380 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
381 unsafe {
382 ffi::mq_leftshift_u8(
383 d_in1_ptr,
384 d_out_ptr,
385 u32::try_from(elem_count)?,
386 self.0 as i32,
387 )
388 };
389 CudaStorage::wrap_cuda_slice(d_out, dev)
390 }
391 DType::I32 => {
392 let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
393 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
394 unsafe {
395 ffi::mq_leftshift_i32(
396 d_in1_ptr,
397 d_out_ptr,
398 u32::try_from(elem_count)?,
399 self.0 as i32,
400 )
401 };
402 CudaStorage::wrap_cuda_slice(d_out, dev)
403 }
404 _ => unreachable!(),
405 };
406 Ok((dst, l1.shape().clone()))
407 }
408 #[cfg(feature = "metal")]
409 fn metal_fwd(
410 &self,
411 s1: &candle_core::MetalStorage,
412 l1: &Layout,
413 ) -> Result<(candle_core::MetalStorage, Shape)> {
414 if !l1.is_contiguous() {
415 candle_core::bail!("Input tensor s1 must be contiguous");
416 }
417
418 let command_buffer = s1.device().command_buffer()?;
419 command_buffer.set_label("bitwise-leftshift");
420
421 let device = s1.device();
422
423 let out_shape = l1.shape().clone();
424
425 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
426
427 crate::metal_kernels::call_bitwise_leftshift(
428 device.device(),
429 &command_buffer,
430 &crate::metal_kernels::Kernels::new(),
431 s1.dtype(),
432 s1.buffer(),
433 l1.start_offset(),
434 self.0 as u32,
435 out_shape.elem_count(),
436 &output,
437 )
438 .map_err(candle_core::Error::wrap)?;
439
440 let newstorage = candle_core::MetalStorage::new(
441 output,
442 device.clone(),
443 out_shape.elem_count(),
444 s1.dtype(),
445 );
446 Ok((newstorage, out_shape))
447 }
448}
449
/// Extension trait adding an element-wise left shift to [`Tensor`].
#[allow(dead_code)]
pub trait LeftshiftOp {
    /// Returns `self << n` element-wise; only integer dtypes are supported
    /// (see the `CustomOp1` impl for per-backend coverage).
    fn leftshift(&self, n: usize) -> Result<Tensor>;
}
454
impl LeftshiftOp for Tensor {
    // Thin delegation to the custom op; no backward pass is registered.
    fn leftshift(&self, n: usize) -> Result<Tensor> {
        self.apply_op1_no_bwd(&Leftshift(n))
    }
}
460
// Fix: the test module was missing `#[cfg(test)]`, so it was compiled into
// every build (dead weight plus unused/dead-code warnings in release).
#[cfg(test)]
mod tests {
    #[test]
    fn test_bitwise_or_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitwise_or_metal() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_leftshift_cpu() {
        use crate::utils::ops::LeftshiftOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_leftshift_cuda() {
        use crate::utils::ops::LeftshiftOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_leftshift_metal() {
        use crate::utils::ops::LeftshiftOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_and_leftshift_cuda() {
        use crate::utils::{ops::BitWiseOp, LeftshiftOp};
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap();
        let b = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap();
        let c = a
            .leftshift(4)
            .unwrap()
            .bitwise_or(&b)
            .unwrap()
            .to_vec1::<u8>()
            .unwrap();
        let av = a.to_vec1::<u8>().unwrap();
        let bv = b.to_vec1::<u8>().unwrap();
        assert_eq!(av, [0b00001111]);
        assert_eq!(bv, [0b00001111]);
        assert_eq!(c, [0b11111111]);
    }

    #[test]
    fn test_bitpack_8bit_cpu() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_8bit_cuda() {
        use crate::HqqBits;
        use candle_core::DType;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_dtype(DType::U8)
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_8bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[test]
    fn test_bitpack_4bit() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_4bit_cuda() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_4bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }
}