use candle_core::{
    backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape,
    Tensor, WithDType,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use std::{
    fmt::Display,
    ops::{BitAnd, BitOr, BitXor, Shl},
};

#[cfg(feature = "cuda")]
use crate::utils::ffi;
#[cfg(feature = "cuda")]
use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
#[cfg(feature = "cuda")]
use std::ffi::c_void;

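/// Element-wise left shift by a constant number of bits, implemented as a
/// `CustomOp1`. Integer dtypes only; see the `LeftshiftOp` extension trait
/// below for the `Tensor`-level entry point.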
struct Leftshift(usize);

impl Leftshift {
    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
        let offset = T::from_f64(self.0 as f64);
        vs.into_par_iter().map(|v| *v << offset).collect()
    }
}

impl CustomOp1 for Leftshift {
    fn name(&self) -> &'static str {
        "left"
    }

    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
        match s1 {
            CpuStorage::U8(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift")),
        }
    }
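    // Note: the CUDA path below ships kernels only for U8 and I32; the other
    // integer dtypes supported on CPU return UnsupportedDTypeForOp on CUDA.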
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        let dev = s1.device().clone();
        let (d_in1_ptr, elem_count) = match s1.dtype() {
            DType::U8 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<u8>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, elem_count)
            }
            DType::I16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift"));
            }
            DType::U32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
            }
            DType::I64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
            }
            DType::I32 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i32>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, elem_count)
            }
            DType::BF16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
            }
            DType::F16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
            }
            DType::F32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
            }
            DType::F64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
            }
            DType::F8E4M3 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    ffi::leftshift_u8(
                        d_in1_ptr,
                        d_out_ptr,
                        u32::try_from(elem_count)?,
                        self.0 as i32,
                    )
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    ffi::leftshift_i32(
                        d_in1_ptr,
                        d_out_ptr,
                        u32::try_from(elem_count)?,
                        self.0 as i32,
                    )
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-leftshift");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;

        crate::metal_kernels::call_bitwise_leftshift(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            s1.dtype(),
            s1.buffer(),
            l1.start_offset(),
            self.0 as u32,
            out_shape.elem_count(),
            &output,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}

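/// Extension trait adding an element-wise left shift to `Tensor`.
///
/// A minimal usage sketch (the tensor binding `t` is illustrative):
/// ```ignore
/// let shifted = t.leftshift(2)?; // t << 2, element-wise
/// ```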
#[allow(dead_code)]
pub trait LeftshiftOp {
    fn leftshift(&self, n: usize) -> Result<Tensor>;
}

impl LeftshiftOp for Tensor {
    fn leftshift(&self, n: usize) -> Result<Tensor> {
        self.apply_op1_no_bwd(&Leftshift(n))
    }
}

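/// The binary bitwise operation applied by the `BitWise` custom op.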
pub enum BitWiseOpEnum {
    And,
    Or,
    Xor,
}

impl Display for BitWiseOpEnum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BitWiseOpEnum::And => write!(f, "And"),
            BitWiseOpEnum::Or => write!(f, "Or"),
            BitWiseOpEnum::Xor => write!(f, "Xor"),
        }
    }
}

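/// Element-wise AND/OR/XOR of two integer tensors with identical shape,
/// stride, and dtype, implemented as a `CustomOp2`. See the `BitWiseOp`
/// extension trait below for the `Tensor`-level entry points.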
struct BitWise {
    pub op: BitWiseOpEnum,
}

impl BitWise {
    pub fn new(op: BitWiseOpEnum) -> Self {
        Self { op }
    }

    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
        &self,
        vs1: &[T],
        vs2: &[T],
    ) -> Vec<T> {
        vs1.into_par_iter()
            .zip_eq(vs2)
            .map(|(v1, v2)| match self.op {
                BitWiseOpEnum::And => *v1 & *v2,
                BitWiseOpEnum::Or => *v1 | *v2,
                BitWiseOpEnum::Xor => *v1 ^ *v2,
            })
            .collect()
    }
}

impl CustomOp2 for BitWise {
    fn name(&self) -> &'static str {
        "bitwise"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        match s1 {
            CpuStorage::U8(vs1) => {
                let vs2 = s2.as_slice::<u8>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs2 = s2.as_slice::<u32>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs2 = s2.as_slice::<i64>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs2 = s2.as_slice::<i16>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs2 = s2.as_slice::<i32>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &CudaStorage,
        l1: &Layout,
        s2: &CudaStorage,
        l2: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        let dev = s1.device().clone();
        let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
            DType::U8 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<u8>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<u8>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::U32 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<u32>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<u32>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I64 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i64>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<i64>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I32 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i32>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<i32>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I16 => {
                // The kernel dispatch below has no I16 arm, so reject I16 here
                // rather than fall through to unreachable!().
                return Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise"));
            }
            DType::BF16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise"));
            }
            DType::F16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise"));
            }
            DType::F32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise"));
            }
            DType::F64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
            }
            DType::F8E4M3 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::U32 => {
                let d_out = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I64 => {
                let d_out = unsafe { dev.alloc::<i64>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
        s2: &candle_core::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-op");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;

        // Metal buffer offsets are given in bytes, so the element offsets are
        // scaled by the dtype size for all three kernels.
        match self.op {
            BitWiseOpEnum::Or => crate::metal_kernels::call_bitwise_or(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
            BitWiseOpEnum::And => crate::metal_kernels::call_bitwise_and(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
            BitWiseOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
        }

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}

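/// Extension trait adding element-wise bitwise ops to `Tensor`.
///
/// A minimal usage sketch (the tensor bindings are illustrative):
/// ```ignore
/// let mask = lt.bitwise_and(&gt)?; // element-wise lt & gt
/// ```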
#[allow(dead_code)]
pub trait BitWiseOp {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
}

impl BitWiseOp for Tensor {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::And))
    }

    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Or))
    }

    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Xor))
    }
}

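/// Returns the indices of all non-zero elements as an `(n_nonzero, n_dims)`
/// u32 tensor, mirroring the semantics of `torch.nonzero`.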
struct NonZero;

impl NonZero {
    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
        let n = layout.dims().len();
        let mut result = Vec::new();
        let mut indices = vec![0u32; n];
        for (i, v) in vs.iter().enumerate() {
            if !v.is_zero() {
                // Decode the flat (row-major) index into per-dimension
                // coordinates, last dimension first.
                let mut idx = i;
                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
                    let d = idx % dim;
                    indices[dim_index] = u32::try_from(d).unwrap();
                    idx /= dim;
                }
                result.extend_from_slice(&indices);
            }
        }
        result
    }
}

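// The CUDA nonzero path is two-pass: `count_nonzero_cuda` first counts the
// non-zero elements so the output buffer can be sized exactly, then
// `nonzero_cuda` writes the (num_nonzero, num_dims) index matrix into it.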
#[cfg(feature = "cuda")]
fn count_nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) -> u32 {
    unsafe {
        match dtype {
            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}

#[allow(clippy::too_many_arguments)]
#[cfg(feature = "cuda")]
fn nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    num_nonzero: u32,
    dims: *const c_void,
    num_dims: u32,
    d_out: *mut c_void,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) {
    unsafe {
        match dtype {
            candle_core::DType::U8 => {
                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::U32 => {
                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I64 => {
                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I32 => {
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I16 => {
                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::BF16 => {
                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F16 => {
                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F32 => {
                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F64 => {
                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}

impl CustomOp1 for NonZero {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "nonzero" });
        }
        let result = match storage {
            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
        };
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
        let result = CpuStorage::U32(result);
        let shape = Shape::from_dims(&[result_len, index_len]);
        Ok((result, shape))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let dev = storage.device().clone();
        let d_in = match storage.dtype() {
            candle_core::DType::U8 => *storage.as_cuda_slice::<u8>()?.device_ptr(),
            candle_core::DType::U32 => *storage.as_cuda_slice::<u32>()?.device_ptr(),
            candle_core::DType::I32 => *storage.as_cuda_slice::<i32>()?.device_ptr(),
            candle_core::DType::I16 => *storage.as_cuda_slice::<i16>()?.device_ptr(),
            candle_core::DType::I64 => *storage.as_cuda_slice::<i64>()?.device_ptr(),
            candle_core::DType::BF16 => *storage.as_cuda_slice::<half::bf16>()?.device_ptr(),
            candle_core::DType::F16 => *storage.as_cuda_slice::<half::f16>()?.device_ptr(),
            candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
            candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
            candle_core::DType::F8E4M3 => todo!(),
        } as *const c_void;
        let n = layout.shape().elem_count();

        let num_nonzero =
            count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream());
        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
        let d_out_ptr = *d_out.device_ptr() as *mut c_void;
        let dims = layout
            .dims()
            .iter()
            .map(|&x| u32::try_from(x).unwrap())
            .collect::<Vec<u32>>();
        let d_dims = dev
            .htod_copy(dims)
            .map_err(|_| Error::Msg("Failed to copy dims to device".to_string()))?;
        let d_dims_ptr = *d_dims.device_ptr() as *const c_void;
        nonzero_cuda(
            storage.dtype(),
            d_in,
            u32::try_from(n)?,
            num_nonzero,
            d_dims_ptr,
            u32::try_from(layout.dims().len())?,
            d_out_ptr,
            *dev.cu_stream(),
        );
        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
        Ok((dst, shape))
    }
}

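/// Extension trait adding `nonzero` to `Tensor`.
///
/// A minimal usage sketch (the tensor binding `mask` is illustrative):
/// ```ignore
/// let idx = mask.nonzero()?; // (n_nonzero, n_dims) u32 indices
/// ```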
pub trait NonZeroOp {
    fn nonzero(&self) -> Result<Tensor>;
}

impl NonZeroOp for Tensor {
    #[cfg(feature = "metal")]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        // There is no Metal kernel for nonzero: run the op on CPU, then move
        // the result back to the original device.
        let original_device = self.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op1_no_bwd(&NonZero)?
            .to_device(original_device)
    }

    #[cfg(not(feature = "metal"))]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        self.apply_op1_no_bwd(&NonZero)
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_nonzero_cpu() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_cuda() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[test]
    fn test_bitwise_and_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_and_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
    }

    #[test]
    fn test_bitwise_or_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_bitwise_xor_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_xor_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_nonzero_and() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let input1 = Tensor::from_vec(
            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input2 = Tensor::from_vec(
            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn nonzero_and_cuda() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let device = Device::new_cuda(0).unwrap();
        let input1 =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input2 =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[test]
    fn test_bitpack_8bit_cpu() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_8bit_cuda() {
        use crate::HqqBits;
        use candle_core::DType;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_dtype(DType::U8)
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_8bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[test]
    fn test_bitpack_4bit() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_4bit_cuda() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_4bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }
}