1use candle_core::{
2 backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3 Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::{
8 fmt::Display,
9 ops::{BitAnd, BitOr, BitXor, Not, Shl},
10};
11
12#[cfg(feature = "cuda")]
13use crate::utils::{ffi, slice_ptr};
14#[cfg(feature = "cuda")]
15use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage};
16#[cfg(feature = "cuda")]
17use std::ffi::c_void;
18
19#[cfg(feature = "metal")]
use crate::metal_kernels::SortScratchCache;
#[cfg(feature = "metal")]
22use std::sync::OnceLock;
23
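/// Process-wide cache of scratch buffers reused across the Metal sort/argsort launches
/// below; initialised lazily on first use via `OnceLock`.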
24#[cfg(feature = "metal")]
25static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
27struct Leftshift(usize);
28
29impl Leftshift {
30 fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31 let offset = T::from_f64(self.0 as f64);
32 vs.into_par_iter().map(|v| *v << offset).collect()
33 }
34}
35
36impl CustomOp1 for Leftshift {
37 fn name(&self) -> &'static str {
38 "left"
39 }
40
41 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42 match s1 {
43 CpuStorage::U8(vs1) => {
44 let vs1 = match l1.contiguous_offsets() {
45 Some((a, b)) => &vs1[a..b],
46 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
47 };
48 let result = self.leftshift(vs1);
49 let result = CpuStorage::U8(result);
50 Ok((result, l1.shape().clone()))
51 }
52 CpuStorage::I16(vs1) => {
53 let vs1 = match l1.contiguous_offsets() {
54 Some((a, b)) => &vs1[a..b],
55 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
56 };
57 let result = self.leftshift(vs1);
58 let result = CpuStorage::I16(result);
59 Ok((result, l1.shape().clone()))
60 }
61 CpuStorage::U32(vs1) => {
62 let vs1 = match l1.contiguous_offsets() {
63 Some((a, b)) => &vs1[a..b],
64 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
65 };
66 let result = self.leftshift(vs1);
67 let result = CpuStorage::U32(result);
68 Ok((result, l1.shape().clone()))
69 }
70 CpuStorage::I64(vs1) => {
71 let vs1 = match l1.contiguous_offsets() {
72 Some((a, b)) => &vs1[a..b],
73 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
74 };
75 let result = self.leftshift(vs1);
76 let result = CpuStorage::I64(result);
77 Ok((result, l1.shape().clone()))
78 }
79 CpuStorage::I32(vs1) => {
80 let vs1 = match l1.contiguous_offsets() {
81 Some((a, b)) => &vs1[a..b],
82 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
83 };
84 let result = self.leftshift(vs1);
85 let result = CpuStorage::I32(result);
86 Ok((result, l1.shape().clone()))
87 }
88 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
89 }
90 }
91
92 #[cfg(feature = "cuda")]
93 fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
94 if !l1.is_contiguous() {
95 candle_core::bail!("Input tensor s1 must be contiguous");
96 }
97 let dev = s1.device().clone();
98 let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
99 DType::U8 => {
100 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
101 let elem_count = l1.shape().elem_count();
102 (d_in1 as *const c_void, d_in1_guard, elem_count)
103 }
104 DType::I32 => {
105 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
106 let elem_count = l1.shape().elem_count();
107 (d_in1 as *const c_void, d_in1_guard, elem_count)
108 }
109 other => {
110 return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
111 }
112 };
113 let dst = match s1.dtype() {
114 DType::U8 => {
115 let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
116 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
117 unsafe {
118 ffi::leftshift_u8(
119 d_in1_ptr,
120 d_out_ptr as *mut std::ffi::c_void,
121 u32::try_from(elem_count)?,
122 self.0 as i32,
123 )
124 };
125 drop(d_out_guard);
126 CudaStorage::wrap_cuda_slice(d_out, dev)
127 }
128 DType::I32 => {
129 let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
130 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
131 unsafe {
132 ffi::leftshift_i32(
133 d_in1_ptr,
134 d_out_ptr as *mut std::ffi::c_void,
135 u32::try_from(elem_count)?,
136 self.0 as i32,
137 )
138 };
139 drop(d_out_guard);
140 CudaStorage::wrap_cuda_slice(d_out, dev)
141 }
142 _ => unreachable!(),
143 };
144 Ok((dst, l1.shape().clone()))
145 }
146
147 #[cfg(feature = "metal")]
148 fn metal_fwd(
149 &self,
150 s1: &candle_core::MetalStorage,
151 l1: &Layout,
152 ) -> Result<(candle_core::MetalStorage, Shape)> {
153 if !l1.is_contiguous() {
154 candle_core::bail!("Input tensor s1 must be contiguous");
155 }
156
157 let command_buffer = s1.device().command_buffer()?;
158 command_buffer.set_label("bitwise-leftshift");
159
160 let device = s1.device();
161
162 let out_shape = l1.shape().clone();
163
164 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
165
166 crate::metal_kernels::call_bitwise_leftshift(
167 device.device(),
168 &command_buffer,
169 &crate::metal_kernels::Kernels::new(),
170 s1.dtype(),
171 s1.buffer(),
172 l1.start_offset(),
173 self.0 as u32,
174 out_shape.elem_count(),
175 &output,
176 )
177 .map_err(candle_core::Error::wrap)?;
178
179 let newstorage = candle_core::MetalStorage::new(
180 output,
181 device.clone(),
182 out_shape.elem_count(),
183 s1.dtype(),
184 );
185 Ok((newstorage, out_shape))
186 }
187}
188
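/// Elementwise left bit-shift of an integer tensor by a constant number of bits.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest), assuming a
/// CPU tensor with one of the integer dtypes supported by the op:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1u8, 2, 4], &[3], &Device::Cpu)?;
/// let shifted = t.leftshift(2)?; // [4, 8, 16]
/// ```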
189#[allow(dead_code)]
190pub trait LeftshiftOp {
191 fn leftshift(&self, n: usize) -> Result<Tensor>;
192}
193
194impl LeftshiftOp for Tensor {
195 fn leftshift(&self, n: usize) -> Result<Tensor> {
196 self.apply_op1_no_bwd(&Leftshift(n))
197 }
198}
199
200pub enum BitWiseBinaryOpEnum {
201 And,
202 Or,
203 Xor,
204}
205
206impl Display for BitWiseBinaryOpEnum {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 BitWiseBinaryOpEnum::And => write!(f, "And"),
210 BitWiseBinaryOpEnum::Or => write!(f, "Or"),
211 BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
212 }
213 }
214}
215
216pub enum BitWiseUnaryOpEnum {
217 Not,
218}
219
220impl Display for BitWiseUnaryOpEnum {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 BitWiseUnaryOpEnum::Not => write!(f, "Not"),
224 }
225 }
226}
227
228struct BitWise {
229 pub op: BitWiseBinaryOpEnum,
230}
231
232impl BitWise {
233 pub fn new(op: BitWiseBinaryOpEnum) -> Self {
234 Self { op }
235 }
236
237 fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
238 &self,
239 vs1: &[T],
240 vs2: &[T],
241 ) -> Vec<T> {
242 vs1.into_par_iter()
243 .zip_eq(vs2)
244 .map(|(v1, v2)| match self.op {
245 BitWiseBinaryOpEnum::And => *v1 & *v2,
246 BitWiseBinaryOpEnum::Or => *v1 | *v2,
247 BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
248 })
249 .collect()
250 }
251}
252
253impl CustomOp2 for BitWise {
254 fn name(&self) -> &'static str {
255 "bitwise"
256 }
257
258 fn cpu_fwd(
259 &self,
260 s1: &CpuStorage,
261 l1: &Layout,
262 s2: &CpuStorage,
263 l2: &Layout,
264 ) -> Result<(CpuStorage, Shape)> {
265 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
266 return Err(Error::ShapeMismatchBinaryOp {
267 lhs: l1.shape().clone(),
268 rhs: l2.shape().clone(),
269 op: "bitwise-op",
270 });
271 }
272 if s1.dtype() != s2.dtype() {
273 return Err(Error::DTypeMismatchBinaryOp {
274 lhs: s1.dtype(),
275 rhs: s2.dtype(),
276 op: "bitwise-op",
277 });
278 }
279 if !l1.is_contiguous() {
280 candle_core::bail!("Input tensor s1 must be contiguous");
281 }
282 if !l2.is_contiguous() {
283 candle_core::bail!("Input tensor s2 must be contiguous");
284 }
285
286 match s1 {
287 CpuStorage::U8(vs1) => {
288 let vs2 = s2.as_slice::<u8>().unwrap();
289 let vs1 = match l1.contiguous_offsets() {
290 Some((a, b)) => &vs1[a..b],
291 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
292 };
293 let vs2 = match l2.contiguous_offsets() {
294 Some((a, b)) => &vs2[a..b],
295 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
296 };
297 let result = self.bitwise(vs1, vs2);
298 let result = CpuStorage::U8(result);
299 Ok((result, l1.shape().clone()))
300 }
301 CpuStorage::U32(vs1) => {
302 let vs2 = s2.as_slice::<u32>().unwrap();
303 let vs1 = match l1.contiguous_offsets() {
304 Some((a, b)) => &vs1[a..b],
305 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
306 };
307 let vs2 = match l2.contiguous_offsets() {
308 Some((a, b)) => &vs2[a..b],
309 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
310 };
311 let result = self.bitwise(vs1, vs2);
312 let result = CpuStorage::U32(result);
313 Ok((result, l1.shape().clone()))
314 }
315 CpuStorage::I64(vs1) => {
316 let vs2 = s2.as_slice::<i64>().unwrap();
317 let vs1 = match l1.contiguous_offsets() {
318 Some((a, b)) => &vs1[a..b],
319 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
320 };
321 let vs2 = match l2.contiguous_offsets() {
322 Some((a, b)) => &vs2[a..b],
323 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
324 };
325 let result = self.bitwise(vs1, vs2);
326 let result = CpuStorage::I64(result);
327 Ok((result, l1.shape().clone()))
328 }
329 CpuStorage::I16(vs1) => {
330 let vs2 = s2.as_slice::<i16>().unwrap();
331 let vs1 = match l1.contiguous_offsets() {
332 Some((a, b)) => &vs1[a..b],
333 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
334 };
335 let vs2 = match l2.contiguous_offsets() {
336 Some((a, b)) => &vs2[a..b],
337 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
338 };
339 let result = self.bitwise(vs1, vs2);
340 let result = CpuStorage::I16(result);
341 Ok((result, l1.shape().clone()))
342 }
343 CpuStorage::I32(vs1) => {
344 let vs2 = s2.as_slice::<i32>().unwrap();
345 let vs1 = match l1.contiguous_offsets() {
346 Some((a, b)) => &vs1[a..b],
347 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
348 };
349 let vs2 = match l2.contiguous_offsets() {
350 Some((a, b)) => &vs2[a..b],
351 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
352 };
353 let result = self.bitwise(vs1, vs2);
354 let result = CpuStorage::I32(result);
355 Ok((result, l1.shape().clone()))
356 }
357 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
358 }
359 }
360
361 #[cfg(feature = "cuda")]
362 fn cuda_fwd(
363 &self,
364 s1: &CudaStorage,
365 l1: &Layout,
366 s2: &CudaStorage,
367 l2: &Layout,
368 ) -> Result<(CudaStorage, Shape)> {
369 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
370 return Err(Error::ShapeMismatchBinaryOp {
371 lhs: l1.shape().clone(),
372 rhs: l2.shape().clone(),
373 op: "bitwise-op",
374 });
375 }
376 if s1.dtype() != s2.dtype() {
377 return Err(Error::DTypeMismatchBinaryOp {
378 lhs: s1.dtype(),
379 rhs: s2.dtype(),
380 op: "bitwise-op",
381 });
382 }
383 if !l1.is_contiguous() {
384 candle_core::bail!("Input tensor s1 must be contiguous");
385 }
386 if !l2.is_contiguous() {
387 candle_core::bail!("Input tensor s2 must be contiguous");
388 }
389
390 let dev = s1.device().clone();
391 let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
392 DType::U8 => {
393 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
394 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
395 let elem_count = l1.shape().elem_count();
396 (
397 d_in1 as *const std::ffi::c_void,
398 d_in2 as *const std::ffi::c_void,
399 d_in1_guard,
400 d_in2_guard,
401 elem_count,
402 )
403 }
404 DType::U32 => {
405 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
406 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
407 let elem_count = l1.shape().elem_count();
408 (
409 d_in1 as *const std::ffi::c_void,
410 d_in2 as *const std::ffi::c_void,
411 d_in1_guard,
412 d_in2_guard,
413 elem_count,
414 )
415 }
416 DType::I64 => {
417 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
418 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
419 let elem_count = l1.shape().elem_count();
420 (
421 d_in1 as *const std::ffi::c_void,
422 d_in2 as *const std::ffi::c_void,
423 d_in1_guard,
424 d_in2_guard,
425 elem_count,
426 )
427 }
428 DType::I32 => {
429 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
430 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
431 let elem_count = l1.shape().elem_count();
432 (
433 d_in1 as *const std::ffi::c_void,
434 d_in2 as *const std::ffi::c_void,
435 d_in1_guard,
436 d_in2_guard,
437 elem_count,
438 )
439 }
440 DType::I16 => {
441 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
442 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
443 let elem_count = l1.shape().elem_count();
444 (
445 d_in1 as *const std::ffi::c_void,
446 d_in2 as *const std::ffi::c_void,
447 d_in1_guard,
448 d_in2_guard,
449 elem_count,
450 )
451 }
452 other => {
453 return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
454 }
455 };
456 let dst = match s1.dtype() {
457 DType::U8 => {
458 let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
459 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
460 unsafe {
461 match self.op {
462 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
463 d_in1_ptr,
464 d_in2_ptr,
465 d_out_ptr as *mut c_void,
466 u32::try_from(elem_count)?,
467 ),
468 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
469 d_in1_ptr,
470 d_in2_ptr,
471 d_out_ptr as *mut c_void,
472 u32::try_from(elem_count)?,
473 ),
474 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
475 d_in1_ptr,
476 d_in2_ptr,
477 d_out_ptr as *mut c_void,
478 u32::try_from(elem_count)?,
479 ),
480 }
481 };
482 drop(d_out_guard);
483 CudaStorage::wrap_cuda_slice(d_out, dev)
484 }
485 DType::U32 => {
486 let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
487 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
488 unsafe {
489 match self.op {
490 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
491 d_in1_ptr,
492 d_in2_ptr,
493 d_out_ptr as *mut c_void,
494 u32::try_from(elem_count)?,
495 ),
496 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
497 d_in1_ptr,
498 d_in2_ptr,
499 d_out_ptr as *mut c_void,
500 u32::try_from(elem_count)?,
501 ),
502 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
503 d_in1_ptr,
504 d_in2_ptr,
505 d_out_ptr as *mut c_void,
506 u32::try_from(elem_count)?,
507 ),
508 }
509 };
510 drop(d_out_guard);
511 CudaStorage::wrap_cuda_slice(d_out, dev)
512 }
513 DType::I64 => {
514 let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
515 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
516 unsafe {
517 match self.op {
518 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
519 d_in1_ptr,
520 d_in2_ptr,
521 d_out_ptr as *mut c_void,
522 u32::try_from(elem_count)?,
523 ),
524 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
525 d_in1_ptr,
526 d_in2_ptr,
527 d_out_ptr as *mut c_void,
528 u32::try_from(elem_count)?,
529 ),
530 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
531 d_in1_ptr,
532 d_in2_ptr,
533 d_out_ptr as *mut c_void,
534 u32::try_from(elem_count)?,
535 ),
536 }
537 };
538 drop(d_out_guard);
539 CudaStorage::wrap_cuda_slice(d_out, dev)
540 }
541 DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
543 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
544 unsafe {
545 match self.op {
546 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
547 d_in1_ptr,
548 d_in2_ptr,
549 d_out_ptr as *mut c_void,
550 u32::try_from(elem_count)?,
551 ),
552 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
553 d_in1_ptr,
554 d_in2_ptr,
555 d_out_ptr as *mut c_void,
556 u32::try_from(elem_count)?,
557 ),
558 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
559 d_in1_ptr,
560 d_in2_ptr,
561 d_out_ptr as *mut c_void,
562 u32::try_from(elem_count)?,
563 ),
564 }
565 };
566 drop(d_out_guard);
567 CudaStorage::wrap_cuda_slice(d_out, dev)
568 }
569 _ => unreachable!(),
570 };
571 Ok((dst, l1.shape().clone()))
572 }
573
574 #[cfg(feature = "metal")]
575 fn metal_fwd(
576 &self,
577 s1: &candle_core::MetalStorage,
578 l1: &Layout,
579 s2: &candle_core::MetalStorage,
580 l2: &Layout,
581 ) -> Result<(candle_core::MetalStorage, Shape)> {
582 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
583 return Err(Error::ShapeMismatchBinaryOp {
584 lhs: l1.shape().clone(),
585 rhs: l2.shape().clone(),
586 op: "bitwise-op",
587 });
588 }
589 if s1.dtype() != s2.dtype() {
590 return Err(Error::DTypeMismatchBinaryOp {
591 lhs: s1.dtype(),
592 rhs: s2.dtype(),
593 op: "bitwise-op",
594 });
595 }
596 if !l1.is_contiguous() {
597 candle_core::bail!("Input tensor s1 must be contiguous");
598 }
599 if !l2.is_contiguous() {
600 candle_core::bail!("Input tensor s2 must be contiguous");
601 }
602
603 let command_buffer = s1.device().command_buffer()?;
604 command_buffer.set_label("bitwise-op");
605
606 let device = s1.device();
607
608 let out_shape = l1.shape().clone();
609
610 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
611
612 match self.op {
613 BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
614 device.device(),
615 &command_buffer,
616 &crate::metal_kernels::Kernels::new(),
617 s1.dtype(),
618 s1.buffer(),
619 s2.buffer(),
620 l1.start_offset() * s1.dtype().size_in_bytes(),
621 l2.start_offset() * s2.dtype().size_in_bytes(),
622 out_shape.elem_count(),
623 &output,
624 )
625 .map_err(candle_core::Error::wrap)?,
626 BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
627 device.device(),
628 &command_buffer,
629 &crate::metal_kernels::Kernels::new(),
630 s1.dtype(),
631 s1.buffer(),
632 s2.buffer(),
633 l1.start_offset() * s1.dtype().size_in_bytes(),
634 l2.start_offset() * s2.dtype().size_in_bytes(),
635 out_shape.elem_count(),
636 &output,
637 )
638 .map_err(candle_core::Error::wrap)?,
639 BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
640 device.device(),
641 &command_buffer,
642 &crate::metal_kernels::Kernels::new(),
643 s1.dtype(),
644 s1.buffer(),
645 s2.buffer(),
646 l1.start_offset() * s1.dtype().size_in_bytes(),
647 l2.start_offset() * s2.dtype().size_in_bytes(),
648 out_shape.elem_count(),
649 &output,
650 )
651 .map_err(candle_core::Error::wrap)?,
652 }
653
654 let newstorage = candle_core::MetalStorage::new(
655 output,
656 device.clone(),
657 out_shape.elem_count(),
658 s1.dtype(),
659 );
660 Ok((newstorage, out_shape))
661 }
662}
663
664struct BitWiseUnary {
665 pub op: BitWiseUnaryOpEnum,
666}
667
668impl BitWiseUnary {
669 pub fn new(op: BitWiseUnaryOpEnum) -> Self {
670 Self { op }
671 }
672
673 fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
674 vs1.into_par_iter()
675 .map(|v1| match self.op {
676 BitWiseUnaryOpEnum::Not => !*v1,
677 })
678 .collect()
679 }
680}
681
682impl CustomOp1 for BitWiseUnary {
683 fn name(&self) -> &'static str {
684 "bitwise-unary"
685 }
686
687 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
688 if !l1.is_contiguous() {
689 candle_core::bail!("Input tensor s1 must be contiguous");
690 }
691
692 match s1 {
693 CpuStorage::U8(vs1) => {
694 let vs1 = match l1.contiguous_offsets() {
695 Some((a, b)) => &vs1[a..b],
696 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
697 };
698 let result = self.bitwise(vs1);
699 let result = CpuStorage::U8(result);
700 Ok((result, l1.shape().clone()))
701 }
702 CpuStorage::U32(vs1) => {
703 let vs1 = match l1.contiguous_offsets() {
704 Some((a, b)) => &vs1[a..b],
705 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
706 };
707 let result = self.bitwise(vs1);
708 let result = CpuStorage::U32(result);
709 Ok((result, l1.shape().clone()))
710 }
711 CpuStorage::I64(vs1) => {
712 let vs1 = match l1.contiguous_offsets() {
713 Some((a, b)) => &vs1[a..b],
714 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
715 };
716 let result = self.bitwise(vs1);
717 let result = CpuStorage::I64(result);
718 Ok((result, l1.shape().clone()))
719 }
720 CpuStorage::I16(vs1) => {
721 let vs1 = match l1.contiguous_offsets() {
722 Some((a, b)) => &vs1[a..b],
723 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
724 };
725 let result = self.bitwise(vs1);
726 let result = CpuStorage::I16(result);
727 Ok((result, l1.shape().clone()))
728 }
729 CpuStorage::I32(vs1) => {
730 let vs1 = match l1.contiguous_offsets() {
731 Some((a, b)) => &vs1[a..b],
732 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
733 };
734 let result = self.bitwise(vs1);
735 let result = CpuStorage::I32(result);
736 Ok((result, l1.shape().clone()))
737 }
738 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
739 }
740 }
741
742 #[cfg(feature = "cuda")]
743 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
744 todo!()
745 }
746
747 #[cfg(feature = "metal")]
748 fn metal_fwd(
749 &self,
750 s1: &candle_core::MetalStorage,
751 l1: &Layout,
752 ) -> Result<(candle_core::MetalStorage, Shape)> {
753 if !l1.is_contiguous() {
754 candle_core::bail!("Input tensor s1 must be contiguous");
755 }
756
757 let command_buffer = s1.device().command_buffer()?;
758 command_buffer.set_label("bitwise-unary-op");
759
760 let device = s1.device();
761
762 let out_shape = l1.shape().clone();
763
764 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
765
766 match self.op {
767 BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
768 device.device(),
769 &command_buffer,
770 &crate::metal_kernels::Kernels::new(),
771 s1.dtype(),
772 s1.buffer(),
773 l1.start_offset() * s1.dtype().size_in_bytes(),
774 out_shape.elem_count(),
775 &output,
776 )
777 .map_err(candle_core::Error::wrap)?,
778 }
779
780 let newstorage = candle_core::MetalStorage::new(
781 output,
782 device.clone(),
783 out_shape.elem_count(),
784 s1.dtype(),
785 );
786 Ok((newstorage, out_shape))
787 }
788}
789
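/// Elementwise bitwise AND / OR / XOR between two integer tensors of identical shape
/// and dtype, plus bitwise NOT of a single tensor.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let a = Tensor::from_vec(vec![0b1100u8, 0b1010], &[2], &Device::Cpu)?;
/// let b = Tensor::from_vec(vec![0b1010u8, 0b0110], &[2], &Device::Cpu)?;
/// let and = a.bitwise_and(&b)?; // [0b1000, 0b0010]
/// let xor = a.bitwise_xor(&b)?; // [0b0110, 0b1100]
/// let not = a.bitwise_not()?;   // [0b1111_0011, 0b1111_0101]
/// ```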
790#[allow(dead_code)]
791pub trait BitWiseOp {
792 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
793 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
794 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
795 fn bitwise_not(&self) -> Result<Tensor>;
796}
797
798impl BitWiseOp for Tensor {
799 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
800 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
801 }
802
803 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
804 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
805 }
806
807 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
808 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
809 }
810
811 fn bitwise_not(&self) -> Result<Tensor> {
812 self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
813 }
814}
815
816#[allow(unused)]
819struct ArgSort {
821 axis: usize,
822}
823
824#[allow(unused)]
825struct Sort {
827 axis: usize,
828}
829
830impl CustomOp1 for ArgSort {
831 fn name(&self) -> &'static str {
832 "argsort"
833 }
834
835 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
837 candle_core::bail!("ArgSort is not implemented for the CPU backend");
838 }
839
840 #[cfg(feature = "cuda")]
842 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
843 candle_core::bail!("ArgSort is not implemented for the CUDA backend");
844 }
845
846 #[cfg(feature = "metal")]
848 fn metal_fwd(
849 &self,
850 s1: &candle_core::MetalStorage,
851 l1: &Layout,
852 ) -> Result<(candle_core::MetalStorage, Shape)> {
853 if !l1.is_contiguous() {
855 candle_core::bail!("Input tensor s1 must be contiguous");
856 }
857
858 let command_buffer = s1.device().command_buffer()?;
860 command_buffer.set_label("argsort");
861
862 let device = s1.device();
863 let out_shape = l1.shape().clone();
864 let elem_count = out_shape.elem_count();
865
866 let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
868
869 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
873
874 let dims = l1.dims();
875 let size_sorted_axis = dims[self.axis];
876 let n_rows = l1.shape().elem_count() / size_sorted_axis;
877
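        // Block-size heuristic: each threadgroup covers `bn * tn` elements of the sorted
        // axis, with `bn` scaled up for longer axes and capped at 256 (rather than 512)
        // for dtypes wider than 4 bytes, presumably to stay within threadgroup-memory
        // limits.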
878 let tn = 4usize;
880 let mut bn = match size_sorted_axis.div_ceil(tn) {
881 v if v > 256 => 512,
882 v if v > 128 => 256,
883 v if v > 64 => 128,
884 v if v > 32 => 64,
885 _ => 32,
886 };
887 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
888 bn = 256;
889 }
890 let n_per_block = bn * tn;
891 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
892
893 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
895
896 let sort_args = crate::metal_kernels::SortArgs {
900 axis: self.axis,
901 shape: l1.dims(),
902 strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
905 in_contiguous: l1.is_contiguous(),
906 in_ty: s1.dtype(),
907 out_ty: candle_core::DType::U32,
908 src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
911 bn,
912 tn,
913 n_blocks,
914 };
915
916 crate::metal_kernels::call_argsort(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
921 &sort_args,
922 &scratch,
923 )
924 .map_err(candle_core::Error::wrap)?;
925
926 let newstorage = candle_core::MetalStorage::new(
928 output,
929 device.clone(),
930 elem_count,
931 candle_core::DType::U32,
932 );
933 Ok((newstorage, out_shape))
934 }
935}
936
937impl CustomOp1 for Sort {
938 fn name(&self) -> &'static str {
939 "sort"
940 }
941
942 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
944 candle_core::bail!("Sort is not implemented for the CPU backend");
945 }
946
947 #[cfg(feature = "cuda")]
949 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
950 candle_core::bail!("Sort is not implemented for the CUDA backend");
951 }
952
953 #[cfg(feature = "metal")]
955 fn metal_fwd(
956 &self,
957 s1: &candle_core::MetalStorage,
958 l1: &Layout,
959 ) -> Result<(candle_core::MetalStorage, Shape)> {
960 if !l1.is_contiguous() {
962 candle_core::bail!("Input tensor s1 must be contiguous");
963 }
964
965 let command_buffer = s1.device().command_buffer()?;
967 command_buffer.set_label("sort");
968
969 let device = s1.device();
970 let out_shape = l1.shape().clone();
971 let elem_count = out_shape.elem_count();
972
973 let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
975
976 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
980
981 let dims = l1.dims();
982 let size_sorted_axis = dims[self.axis];
983 let n_rows = l1.shape().elem_count() / size_sorted_axis;
984
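        // Same block-size heuristic as in `ArgSort::metal_fwd` above.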
985 let tn = 4usize;
987 let mut bn = match size_sorted_axis.div_ceil(tn) {
988 v if v > 256 => 512,
989 v if v > 128 => 256,
990 v if v > 64 => 128,
991 v if v > 32 => 64,
992 _ => 32,
993 };
994 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
995 bn = 256;
996 }
997 let n_per_block = bn * tn;
998 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
999
1000 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1002
1003 let sort_args = crate::metal_kernels::SortArgs {
1007 axis: self.axis,
1008 shape: l1.dims(),
1009 strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
1012 in_contiguous: l1.is_contiguous(),
1013 in_ty: s1.dtype(),
1014 out_ty: s1.dtype(),
1015 src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
1018 bn,
1019 tn,
1020 n_blocks,
1021 };
1022
1023 crate::metal_kernels::call_sort(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
1028 &sort_args,
1029 &scratch,
1030 )
1031 .map_err(candle_core::Error::wrap)?;
1032
1033 let newstorage =
1035 candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1036 Ok((newstorage, out_shape))
1037 }
1038}
1039
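/// Ascending sort and argsort along a chosen axis.
///
/// On CPU and CUDA devices these delegate to candle's built-in `sort_last_dim` /
/// `arg_sort_last_dim` (which always operate on the last dimension regardless of
/// `axis`); the custom kernels are only used on Metal. A minimal usage sketch
/// (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &Device::Cpu)?;
/// let sorted = t.fast_sort_asc(0)?;     // [1, 2, 3, 4]
/// let indices = t.fast_argsort_asc(0)?; // [1, 3, 0, 2]
/// ```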
1040pub trait SortOp {
1042 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1044 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1046}
1047
1048impl SortOp for Tensor {
1049 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1050 if self.device().is_cpu() || self.device().is_cuda() {
1051 return self.arg_sort_last_dim(true);
1052 }
1053 self.apply_op1_no_bwd(&ArgSort {
1054 axis: axis.to_index(self.shape(), "argsort")?,
1055 })
1056 }
1057
1058 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1059 if self.device().is_cpu() || self.device().is_cuda() {
1060 return Ok(self.sort_last_dim(true)?.0);
1061 }
1062 self.apply_op1_no_bwd(&Sort {
1063 axis: axis.to_index(self.shape(), "sort")?,
1064 })
1065 }
1066}
1067
1068struct NonZero;
1069
1070impl NonZero {
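    // For every non-zero element, convert its flat (row-major) index into one
    // coordinate per dimension and append them to `result`. E.g. with shape [2, 4],
    // flat index 6 unravels to [1, 2].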
1071 fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1073 let n = layout.dims().len();
1074 let mut result = Vec::new();
1075 let mut indices = vec![0u32; n];
1076 for (i, v) in vs.iter().enumerate() {
1077 if !v.is_zero() {
1078 let mut idx = i;
1079 for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1080 let d = idx % dim;
1081 indices[dim_index] = u32::try_from(d).unwrap();
1082 idx /= dim;
1083 }
1084 result.extend_from_slice(&indices);
1085 }
1086 }
1087 result
1088 }
1089}
1090
1091#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1092mod cuda_ops_cccl2 {
1093 use super::*;
1094
1095 pub(super) fn count_nonzero_cuda(
1096 dtype: candle_core::DType,
1097 d_in: *const c_void,
1098 n: u32,
1099 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1100 ) -> u32 {
1101 unsafe {
1102 match dtype {
1103 candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1104 candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1105 candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1106 candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1107 candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1108 candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1109 candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1110 candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1111 candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1112 _ => unreachable!(),
1113 }
1114 }
1115 }
1116
1117 #[allow(clippy::too_many_arguments)]
1118 pub(super) fn nonzero_cuda(
1119 dtype: candle_core::DType,
1120 d_in: *const c_void,
1121 n: u32,
1122 num_nonzero: u32,
1123 dims: *const c_void,
1124 num_dims: u32,
1125 d_out: *mut c_void,
1126 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1127 ) {
1128 unsafe {
1129 match dtype {
1130 candle_core::DType::U8 => {
1131 ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1132 }
1133 candle_core::DType::U32 => {
1134 ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1135 }
1136 candle_core::DType::I64 => {
1137 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1138 }
1139 candle_core::DType::I32 => {
                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1141 }
1142 candle_core::DType::I16 => {
1143 ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1144 }
1145 candle_core::DType::BF16 => {
1146 ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1147 }
1148 candle_core::DType::F16 => {
1149 ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1150 }
1151 candle_core::DType::F32 => {
1152 ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1153 }
1154 candle_core::DType::F64 => {
1155 ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1156 }
1157 _ => unreachable!(),
1158 }
1159 }
1160 }
1161}
1162
1163#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1164mod cuda_ops_cccl3 {
1165 use super::*;
1166
1167 pub(super) fn count_nonzero_cuda(
1168 dtype: candle_core::DType,
1169 d_in: *const c_void,
1170 n: u32,
1171 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1172 ) -> u32 {
1173 unsafe {
1174 match dtype {
1175 candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1176 candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1177 candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1178 candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1179 candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1180 candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1181 candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1182 candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1183 candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1184 _ => unreachable!(),
1185 }
1186 }
1187 }
1188
1189 #[allow(clippy::too_many_arguments)]
1190 pub(super) fn nonzero_cuda(
1191 dtype: candle_core::DType,
1192 d_in: *const c_void,
1193 n: u32,
1194 num_nonzero: u32,
1195 dims: *const c_void,
1196 num_dims: u32,
1197 d_out: *mut c_void,
1198 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1199 ) {
1200 unsafe {
1201 match dtype {
1202 candle_core::DType::U8 => {
1203 ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1204 }
1205 candle_core::DType::U32 => {
1206 ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1207 }
1208 candle_core::DType::I64 => {
1209 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1210 }
1211 candle_core::DType::I32 => {
                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1213 }
1214 candle_core::DType::I16 => {
1215 ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1216 }
1217 candle_core::DType::BF16 => {
1218 ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1219 }
1220 candle_core::DType::F16 => {
1221 ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1222 }
1223 candle_core::DType::F32 => {
1224 ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1225 }
1226 candle_core::DType::F64 => {
1227 ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1228 }
1229 _ => unreachable!(),
1230 }
1231 }
1232 }
1233}
1234
1235#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1236use cuda_ops_cccl2::{count_nonzero_cuda, nonzero_cuda};
1237#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1238use cuda_ops_cccl3::{count_nonzero_cuda, nonzero_cuda};
1239
1240impl CustomOp1 for NonZero {
1241 fn name(&self) -> &'static str {
1242 "nonzero"
1243 }
1244
1245 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1246 if !layout.is_contiguous() {
1247 return Err(Error::RequiresContiguous { op: "nonzero" });
1248 }
1249 let result = match storage {
1250 candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1251 candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1252 candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1253 candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1254 candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1255 candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1256 candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1257 candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1258 candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1259 _ => unreachable!(),
1260 };
1261 let index_len = layout.dims().len();
1262 let result_len = result.len() / index_len;
1263 let result = CpuStorage::U32(result);
1264 let shape = Shape::from_dims(&[result_len, index_len]);
1265 Ok((result, shape))
1266 }
1267
1268 #[cfg(feature = "cuda")]
1269 fn cuda_fwd(
1270 &self,
1271 storage: &candle_core::CudaStorage,
1272 layout: &Layout,
1273 ) -> Result<(candle_core::CudaStorage, Shape)> {
1274 if !layout.is_contiguous() {
1275 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1276 }
1277 let dev = storage.device().clone();
1278 let (d_in, _d_in_guard) = match storage.dtype() {
1279 candle_core::DType::U8 => {
1280 let slice = storage.as_cuda_slice::<u8>()?;
1281 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1282 (d_in as *const std::ffi::c_void, d_in_guard)
1283 }
1284 candle_core::DType::U32 => {
1285 let slice = storage.as_cuda_slice::<u32>()?;
1286 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1287 (d_in as *const std::ffi::c_void, d_in_guard)
1288 }
1289 candle_core::DType::I32 => {
1290 let slice = storage.as_cuda_slice::<i32>()?;
1291 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1292 (d_in as *const std::ffi::c_void, d_in_guard)
1293 }
1294 candle_core::DType::I16 => {
1295 let slice = storage.as_cuda_slice::<i16>()?;
1296 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1297 (d_in as *const std::ffi::c_void, d_in_guard)
1298 }
1299 candle_core::DType::I64 => {
1300 let slice = storage.as_cuda_slice::<i64>()?;
1301 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1302 (d_in as *const std::ffi::c_void, d_in_guard)
1303 }
1304 candle_core::DType::BF16 => {
1305 let slice = storage.as_cuda_slice::<half::bf16>()?;
1306 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1307 (d_in as *const std::ffi::c_void, d_in_guard)
1308 }
1309 candle_core::DType::F16 => {
1310 let slice = storage.as_cuda_slice::<half::f16>()?;
1311 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1312 (d_in as *const std::ffi::c_void, d_in_guard)
1313 }
1314 candle_core::DType::F32 => {
1315 let slice = storage.as_cuda_slice::<f32>()?;
1316 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1317 (d_in as *const std::ffi::c_void, d_in_guard)
1318 }
1319 candle_core::DType::F64 => {
1320 let slice = storage.as_cuda_slice::<f64>()?;
1321 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1322 (d_in as *const std::ffi::c_void, d_in_guard)
1323 }
1324 _ => unreachable!(),
1325 };
1326 let n = layout.shape().elem_count();
1327
1328 let num_nonzero = count_nonzero_cuda(
1329 storage.dtype(),
1330 d_in,
1331 u32::try_from(n)?,
1332 dev.cuda_stream().cu_stream(),
1333 );
1334 let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1335 .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1336 if num_nonzero != 0 {
1337 let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1338 let dims = layout
1339 .dims()
1340 .iter()
1341 .map(|&x| u32::try_from(x).unwrap())
1342 .collect::<Vec<u32>>();
1343 let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1344 dev.memcpy_htod(&dims, &mut d_dims)?;
1345 let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1346 nonzero_cuda(
1347 storage.dtype(),
1348 d_in,
1349 u32::try_from(n)?,
1350 num_nonzero,
1351 d_dims_ptr as *const c_void,
1352 u32::try_from(layout.dims().len())?,
1353 d_out as *mut c_void,
1354 dev.cuda_stream().cu_stream(),
1355 );
1356 }
1357 let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1358 let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1359 Ok((dst, shape))
1360 }
1361}
1362
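/// Indices of all non-zero elements, returned as a `[num_nonzero, rank]` u32 tensor
/// with one row of coordinates per non-zero element (the Metal path currently routes
/// through the CPU implementation).
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1f32, 0., 0., 2.], &[2, 2], &Device::Cpu)?;
/// let idx = t.nonzero()?; // [[0, 0], [1, 1]]
/// ```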
1363pub trait NonZeroOp {
1364 fn nonzero(&self) -> Result<Tensor>;
1365}
1366
1367impl NonZeroOp for Tensor {
1368 #[cfg(feature = "metal")]
1369 fn nonzero(&self) -> Result<Tensor> {
1370 if !self.is_contiguous() {
1371 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1372 }
1373 let original_device = self.device();
1374 self.to_device(&candle_core::Device::Cpu)?
1375 .apply_op1_no_bwd(&NonZero)?
1376 .to_device(original_device)
1377 }
1378
1379 #[cfg(not(feature = "metal"))]
1380 fn nonzero(&self) -> Result<Tensor> {
1381 if !self.is_contiguous() {
1382 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1383 }
1384 self.apply_op1_no_bwd(&NonZero)
1385 }
1386}
1387
1388struct CumSum {
1389 inclusive: bool,
1390 reverse: bool,
1391 axis: usize,
1392}
1393
1394impl CustomOp1 for CumSum {
1395 fn name(&self) -> &'static str {
1396 "cumsum"
1397 }
1398
1399 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1400 use std::ops::Add;
1401 if !l1.is_contiguous() {
1402 candle_core::bail!("Input tensor s1 must be contiguous");
1403 }
1404 let dims = l1.dims();
1405 let axis = self.axis;
1406 let axis_len = dims[axis];
1407 let (start, end) = l1
1408 .contiguous_offsets()
1409 .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1410
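        // The macro below scans `count` consecutive blocks of length `axis_len`
        // (i.e. it assumes the scan axis is the innermost dimension of the contiguous
        // input) and emits, per block, an inclusive or exclusive prefix sum, optionally
        // running from the end of each block when `reverse` is set.
        // E.g. [1, 2, 3, 4] -> inclusive [1, 3, 6, 10], exclusive [0, 1, 3, 6].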
1411 macro_rules! scan_block {
1413 ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1414 let vs: &[$ty] = $vt;
1415 let input = &vs[start..end];
1416 let count = input.len() / axis_len;
1417 let mut result = Vec::<$ty>::with_capacity(input.len());
1418 if !self.reverse {
1419 if self.inclusive {
1420 for block in 0..count {
1421 let base = block * axis_len;
1422 let mut sum = input[base];
1423 result.push(sum);
1424 for j in 1..axis_len {
1425 sum = sum.$add(input[base + j]);
1426 result.push(sum);
1427 }
1428 }
1429 } else {
1430 let init: $ty = $init;
1431 for block in 0..count {
1432 let base = block * axis_len;
1433 let mut sum = init;
1434 for j in 0..axis_len {
1435 result.push(sum);
1436 sum = sum.$add(input[base + j]);
1437 }
1438 }
1439 }
1440 } else {
1441 if self.inclusive {
1442 for block in 0..count {
1443 let base = block * axis_len;
1444 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1445 let mut sum = input[base + axis_len - 1];
1446 temp.push(sum);
1447 for k in 1..axis_len {
1448 let idx = axis_len - 1 - k;
1449 sum = sum.$add(input[base + idx]);
1450 temp.push(sum);
1451 }
1452 temp.reverse();
1453 result.extend(temp);
1454 }
1455 } else {
1456 let init: $ty = $init;
1457 for block in 0..count {
1458 let base = block * axis_len;
1459 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1460 let mut sum = init;
1461 for k in 0..axis_len {
1462 let idx = axis_len - 1 - k;
1463 temp.push(sum);
1464 sum = sum.$add(input[base + idx]);
1465 }
1466 temp.reverse();
1467 result.extend(temp);
1468 }
1469 }
1470 }
1471 result
1472 }};
1473 }
1474 match s1 {
1475 CpuStorage::U8(vs) => {
1476 let result = scan_block!(vs, u8, wrapping_add, 0u8);
1477 Ok((CpuStorage::U8(result), l1.shape().clone()))
1478 }
1479 CpuStorage::I16(vs) => {
1480 let result = scan_block!(vs, i16, add, 0i16);
1481 Ok((CpuStorage::I16(result), l1.shape().clone()))
1482 }
1483 CpuStorage::U32(vs) => {
1484 let result = scan_block!(vs, u32, wrapping_add, 0u32);
1485 Ok((CpuStorage::U32(result), l1.shape().clone()))
1486 }
1487 CpuStorage::I32(vs) => {
1488 let result = scan_block!(vs, i32, add, 0i32);
1489 Ok((CpuStorage::I32(result), l1.shape().clone()))
1490 }
1491 CpuStorage::I64(vs) => {
1492 let result = scan_block!(vs, i64, add, 0i64);
1493 Ok((CpuStorage::I64(result), l1.shape().clone()))
1494 }
1495 CpuStorage::F32(vs) => {
1496 let result = scan_block!(vs, f32, add, 0.0f32);
1497 Ok((CpuStorage::F32(result), l1.shape().clone()))
1498 }
1499 CpuStorage::F64(vs) => {
1500 let result = scan_block!(vs, f64, add, 0.0f64);
1501 Ok((CpuStorage::F64(result), l1.shape().clone()))
1502 }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1504 }
1505 }
1506
1507 #[cfg(feature = "cuda")]
1508 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
1509 todo!()
1510 }
1511
1512 #[cfg(feature = "metal")]
1513 fn metal_fwd(
1514 &self,
1515 s1: &candle_core::MetalStorage,
1516 l1: &Layout,
1517 ) -> Result<(candle_core::MetalStorage, Shape)> {
1518 use crate::metal_kernels::ScanType;
1519
1520 let command_buffer = s1.device().command_buffer()?;
1521 command_buffer.set_label("cumsum");
1522
1523 let device = s1.device();
1524
1525 let out_shape = l1.shape().clone();
1526
1527 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1528
1529 crate::metal_kernels::call_scan(
1530 device.device(),
1531 &command_buffer,
1532 &crate::metal_kernels::Kernels::new(),
1533 s1.dtype(),
1534 ScanType::Sum,
1535 s1.buffer(),
1536 l1.start_offset() * s1.dtype().size_in_bytes(),
1537 self.axis,
1538 l1.dims(),
1539 l1.stride(),
1540 self.reverse,
1541 self.inclusive,
1542 &output,
1543 )
1544 .map_err(candle_core::Error::wrap)?;
1545
1546 let newstorage = candle_core::MetalStorage::new(
1547 output,
1548 device.clone(),
1549 out_shape.elem_count(),
1550 s1.dtype(),
1551 );
1552 Ok((newstorage, out_shape))
1553 }
1554}
1555
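/// Cumulative sum along an axis. `fast_cumsum` performs the exclusive forward scan;
/// `fast_cumsum_config` exposes the inclusive and reverse variants.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest), mirroring
/// the CPU tests below:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu)?;
/// let exclusive = t.fast_cumsum(0)?;                     // [0, 1, 3, 6]
/// let inclusive = t.fast_cumsum_config(0, true, false)?; // [1, 3, 6, 10]
/// let reversed  = t.fast_cumsum_config(0, false, true)?; // [9, 7, 4, 0]
/// ```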
1556#[allow(dead_code)]
1557pub trait CumSumOp {
1558 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1560
1561 fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1562 -> Result<Tensor>;
1563}
1564
1565impl CumSumOp for Tensor {
1566 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1567 self.fast_cumsum_config(axis, false, false)
1568 }
1569
1570 fn fast_cumsum_config<D: Dim>(
1571 &self,
1572 axis: D,
1573 inclusive: bool,
1574 reverse: bool,
1575 ) -> Result<Tensor> {
1576 self.apply_op1_no_bwd(&CumSum {
1577 inclusive,
1578 reverse,
1579 axis: axis.to_index(self.shape(), "cumsum")?,
1580 })
1581 }
1582}
1583
#[cfg(test)]
mod tests {
1585 #[test]
1586 fn test_cumsum_exclusive_forward_cpu() {
1587 use crate::utils::ops::CumSumOp;
1588 use candle_core::Tensor;
1589 let device = candle_core::Device::Cpu;
1590 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1591 let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1592 assert_eq!(b, [0, 1, 3, 6]);
1593 }
1594
1595 #[test]
1596 fn test_cumsum_inclusive_forward_cpu() {
1597 use crate::utils::ops::CumSumOp;
1598 use candle_core::Tensor;
1599 let device = candle_core::Device::Cpu;
1600 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1601 let b = a
1602 .fast_cumsum_config(0, true, false)
1603 .unwrap()
1604 .to_vec1::<i64>()
1605 .unwrap();
1606 assert_eq!(b, [1, 3, 6, 10]);
1607 }
1608
1609 #[test]
1610 fn test_cumsum_exclusive_reverse_cpu() {
1611 use crate::utils::ops::CumSumOp;
1612 use candle_core::Tensor;
1613 let device = candle_core::Device::Cpu;
1614 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1615 let b = a
1616 .fast_cumsum_config(0, false, true)
1617 .unwrap()
1618 .to_vec1::<i64>()
1619 .unwrap();
1620 assert_eq!(b, [9, 7, 4, 0]);
1621 }
1622
1623 #[test]
1624 fn test_cumsum_inclusive_reverse_cpu() {
1625 use crate::utils::ops::CumSumOp;
1626 use candle_core::Tensor;
1627 let device = candle_core::Device::Cpu;
1628 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1629 let b = a
1630 .fast_cumsum_config(0, true, true)
1631 .unwrap()
1632 .to_vec1::<i64>()
1633 .unwrap();
1634 assert_eq!(b, [10, 9, 7, 4]);
1635 }
1636
1637 #[cfg(feature = "metal")]
1638 #[test]
1639 fn test_cumsum_exclusive_forward_metal() {
1640 use crate::utils::ops::CumSumOp;
1641 use candle_core::Tensor;
1642 let device = candle_core::Device::new_metal(0).unwrap();
1643 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1644 let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1645 assert_eq!(b, [0, 1, 3, 6]);
1646 }
1647
1648 #[cfg(feature = "metal")]
1649 #[test]
1650 fn test_cumsum_inclusive_forward_metal() {
1651 use crate::utils::ops::CumSumOp;
1652 use candle_core::Tensor;
1653 let device = candle_core::Device::new_metal(0).unwrap();
1654 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1655 let b = a
1656 .fast_cumsum_config(0, true, false)
1657 .unwrap()
1658 .to_vec1::<i64>()
1659 .unwrap();
1660 assert_eq!(b, [1, 3, 6, 10]);
1661 }
1662
1663 #[cfg(feature = "metal")]
1664 #[test]
1665 fn test_cumsum_exclusive_reverse_metal() {
1666 use crate::utils::ops::CumSumOp;
1667 use candle_core::Tensor;
1668 let device = candle_core::Device::new_metal(0).unwrap();
1669 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1670 let b = a
1671 .fast_cumsum_config(0, false, true)
1672 .unwrap()
1673 .to_vec1::<i64>()
1674 .unwrap();
1675 assert_eq!(b, [9, 7, 4, 0]);
1676 }
1677
1678 #[cfg(feature = "metal")]
1679 #[test]
1680 fn test_cumsum_inclusive_reverse_metal() {
1681 use crate::utils::ops::CumSumOp;
1682 use candle_core::Tensor;
1683 let device = candle_core::Device::new_metal(0).unwrap();
1684 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1685 let b = a
1686 .fast_cumsum_config(0, true, true)
1687 .unwrap()
1688 .to_vec1::<i64>()
1689 .unwrap();
1690 assert_eq!(b, [10, 9, 7, 4]);
1691 }
1692
1693 #[test]
1694 fn test_nonzero_cpu() {
1695 use crate::utils::ops::NonZeroOp;
1696 use candle_core::Tensor;
1697 let device = candle_core::Device::Cpu;
1698 let a = Tensor::from_vec(
1699 vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1700 &[2, 4],
1701 &device,
1702 )
1703 .unwrap();
1704 let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1705 assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1706 }
1707
1708 #[cfg(feature = "cuda")]
1709 #[test]
1710 fn test_nonzero_cuda() {
1711 use crate::utils::ops::NonZeroOp;
1712 use candle_core::Tensor;
1713 let device = candle_core::Device::new_cuda(0).unwrap();
1714 let a = Tensor::from_vec(
1715 vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1716 &[2, 4],
1717 &device,
1718 )
1719 .unwrap();
1720 let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1721 assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1722 }
1723
1724 #[test]
1725 fn test_bitwise_and_cpu() {
1726 use crate::utils::ops::BitWiseOp;
1727 use candle_core::Tensor;
1728 let device = candle_core::Device::Cpu;
1729 let a =
1730 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1731 let b =
1732 Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1733 let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1734 assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
1735 }
1736
1737 #[cfg(feature = "cuda")]
1738 #[test]
1739 fn test_bitwise_and_cuda() {
1740 use crate::utils::ops::BitWiseOp;
1741 use candle_core::Tensor;
1742 let device = candle_core::Device::new_cuda(0).unwrap();
1743 let a =
1744 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1745 let b =
1746 Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
1747 let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1748 assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
1749 }
1750
1751 #[test]
1752 fn test_bitwise_or_cpu() {
1753 use crate::utils::ops::BitWiseOp;
1754 use candle_core::Tensor;
1755 let device = candle_core::Device::Cpu;
1756 let a =
1757 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1758 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1759 let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1760 assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1761 }
1762
1763 #[cfg(feature = "cuda")]
1764 #[test]
1765 fn test_bitwise_or_cuda() {
1766 use crate::utils::ops::BitWiseOp;
1767 use candle_core::Tensor;
1768 let device = candle_core::Device::new_cuda(0).unwrap();
1769 let a =
1770 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1771 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1772 let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1773 assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1774 }
1775
1776 #[test]
1777 fn test_bitwise_xor_cpu() {
1778 use crate::utils::ops::BitWiseOp;
1779 use candle_core::Tensor;
1780 let device = candle_core::Device::Cpu;
1781 let a =
1782 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1783 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1784 let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1785 assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1786 }
1787
1788 #[cfg(feature = "cuda")]
1789 #[test]
1790 fn test_bitwise_xor_cuda() {
1791 use crate::utils::ops::BitWiseOp;
1792 use candle_core::Tensor;
1793 let device = candle_core::Device::new_cuda(0).unwrap();
1794 let a =
1795 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1796 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1797 let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1798 assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1799 }
1800
1801 #[test]
1802 fn test_nonzero_and() {
1803 use crate::utils::ops::{BitWiseOp, NonZeroOp};
1804 use candle_core::{Device, Tensor};
1805
1806 let input1 = Tensor::from_vec(
1807 vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
1808 (10,),
1809 &Device::Cpu,
1810 )
1811 .unwrap();
1812 let input2 = Tensor::from_vec(
1813 vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
1814 (10,),
1815 &Device::Cpu,
1816 )
1817 .unwrap();
1818 let input = Tensor::stack(&[input1, input2], 0).unwrap();
1819
1820 let lt = input.lt(0.0).unwrap();
1821 let gt = input.gt(-10.0).unwrap();
1822 let res = lt
1823 .bitwise_and(>)
1824 .unwrap()
1825 .nonzero()
1826 .unwrap()
1827 .to_vec2::<u32>()
1828 .unwrap();
1829
1830 assert_eq!(
1831 res,
1832 [
1833 [0, 3],
1834 [0, 4],
1835 [0, 5],
1836 [0, 6],
1837 [1, 0],
1838 [1, 3],
1839 [1, 5],
1840 [1, 6]
1841 ]
1842 );
1843 }
1844
1845 #[cfg(feature = "cuda")]
1846 #[test]
1847 fn nonzero_and_cuda() {
1848 use crate::utils::ops::{BitWiseOp, NonZeroOp};
1849 use candle_core::{Device, Tensor};
1850
1851 let device = Device::new_cuda(0).unwrap();
1852 let input1 =
1853 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1854 let input2 =
1855 Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1856 let input = Tensor::stack(&[input1, input2], 0).unwrap();
1857
1858 let lt = input.lt(0.0).unwrap();
1859 let gt = input.gt(-10.0).unwrap();
1860 let res = lt
1861 .bitwise_and(>)
1862 .unwrap()
1863 .nonzero()
1864 .unwrap()
1865 .to_vec2::<u32>()
1866 .unwrap();
1867
1868 assert_eq!(
1869 res,
1870 [
1871 [0, 3],
1872 [0, 4],
1873 [0, 5],
1874 [0, 6],
1875 [1, 0],
1876 [1, 3],
1877 [1, 5],
1878 [1, 6]
1879 ]
1880 );
1881 }
1882
1883 #[test]
1884 fn test_bitpack_8bit_cpu() {
1885 use crate::HqqBits;
1886 use candle_core::{Device, Tensor};
1887 let bits = HqqBits::Eight;
1888 let device = Device::Cpu;
1889 let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1890 let c = bits.bitpack_type()(wq.clone())
1891 .unwrap()
1892 .to_vec2::<u8>()
1893 .unwrap();
1894 assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1895 }
1896
1897 #[cfg(feature = "cuda")]
1898 #[test]
1899 fn test_bitpack_8bit_cuda() {
1900 use crate::HqqBits;
1901 use candle_core::DType;
1902 use candle_core::{Device, Tensor};
1903 let bits = HqqBits::Eight;
1904 let device = Device::new_cuda(0).unwrap();
1905 let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1906 let c = bits.bitpack_type()(wq.clone())
1907 .unwrap()
1908 .to_dtype(DType::U8)
1909 .unwrap()
1910 .to_vec2::<u8>()
1911 .unwrap();
1912 assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1913 }
1914
1915 #[cfg(feature = "metal")]
1916 #[test]
1917 fn test_bitpack_8bit_metal() {
1918 use crate::HqqBits;
1919 use candle_core::{Device, Tensor};
1920 let bits = HqqBits::Eight;
1921 let device = Device::new_metal(0).unwrap();
1922 let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1923 let c = bits.bitpack_type()(wq.clone())
1924 .unwrap()
1925 .to_vec2::<u8>()
1926 .unwrap();
1927 assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1928 }
1929
1930 #[test]
1931 fn test_bitpack_4bit() {
1932 use crate::HqqBits;
1933 use candle_core::{Device, Tensor};
1934 let bits = HqqBits::Four;
1935 let device = Device::Cpu;
1936 let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1937 let c = bits.bitpack_type()(wq.clone())
1938 .unwrap()
1939 .to_vec2::<u8>()
1940 .unwrap();
1941 assert_eq!(c, [[19, 36]]);
1942 }
1943
1944 #[cfg(feature = "cuda")]
1945 #[test]
1946 fn test_bitpack_4bit_cuda() {
1947 use crate::HqqBits;
1948 use candle_core::{Device, Tensor};
1949 let bits = HqqBits::Four;
1950 let device = Device::new_cuda(0).unwrap();
1951 let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1952 let c = bits.bitpack_type()(wq.clone())
1953 .unwrap()
1954 .to_vec2::<u8>()
1955 .unwrap();
1956 assert_eq!(c, [[19, 36]]);
1957 }
1958
1959 #[cfg(feature = "metal")]
1960 #[test]
1961 fn test_bitpack_4bit_metal() {
1962 use crate::HqqBits;
1963 use candle_core::{Device, Tensor};
1964 let bits = HqqBits::Four;
1965 let device = Device::new_metal(0).unwrap();
1966 let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1967 let c = bits.bitpack_type()(wq.clone())
1968 .unwrap()
1969 .to_vec2::<u8>()
1970 .unwrap();
1971 assert_eq!(c, [[19, 36]]);
1972 }

    #[cfg(feature = "metal")]
1975 #[test]
1976 fn test_sort_and_argsort_vector_metal() {
1977 use crate::utils::ops::SortOp;
1978 use candle_core::Tensor;
1979
1980 let device = candle_core::Device::new_metal(0).unwrap();
1981 let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
1982
1983 let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1985 assert_eq!(sorted, [1, 2, 3, 4]);
1986
1987 let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1989 assert_eq!(idx, [1, 3, 0, 2]);
1990 }
1991
1992 #[cfg(feature = "metal")]
1993 #[test]
1994 fn test_sort_and_argsort_matrix_axis1_metal() {
1995 use crate::utils::ops::SortOp;
1996 use candle_core::Tensor;
1997
1998 let device = candle_core::Device::new_metal(0).unwrap();
1999 let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
2003
2004 let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
2006 assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
2007
2008 let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
2010 assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
2011 }
2012
2013 #[cfg(feature = "metal")]
2015 #[test]
    fn test_sort_and_argsort_vector_4096_metal() {
2017 use crate::utils::ops::SortOp;
2018 use candle_core::Tensor;
2019
2020 const N: usize = 4096;
2021
2022 let device = candle_core::Device::new_metal(0).expect("Metal device");
2023
2024 let vals: Vec<i32> = (0..N as i32).rev().collect();
2026 let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
2027
2028 let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
2030 let expected: Vec<i32> = (0..N as i32).collect();
2031 assert_eq!(sorted, expected);
2032
2033 let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
2035 for (i, &v) in idx.iter().enumerate() {
2037 assert_eq!(v as usize, N - 1 - i);
2038 }
2039 }
2040}