use candle_core::{
    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
    Result, Shape, Tensor, WithDType,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use std::{
    fmt::Display,
    ops::{BitAnd, BitOr, BitXor, Not, Shl},
};

#[cfg(feature = "cuda")]
use crate::utils::{ffi, slice_ptr};
#[cfg(feature = "cuda")]
use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage};
#[cfg(feature = "cuda")]
use std::ffi::c_void;

#[cfg(feature = "metal")]
use crate::metal_kernels::SortScratchCache;
#[cfg(feature = "metal")]
use std::sync::OnceLock;
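
/// Process-wide scratch-buffer cache shared by the Metal sort / argsort kernels below.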
#[cfg(feature = "metal")]
static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();

struct Leftshift(usize);

impl Leftshift {
    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
        let offset = T::from_f64(self.0 as f64);
        vs.into_par_iter().map(|v| *v << offset).collect()
    }
}

impl CustomOp1 for Leftshift {
    fn name(&self) -> &'static str {
        "left"
    }

    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
        match s1 {
            CpuStorage::U8(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        let dev = s1.device().clone();
        let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
            DType::U8 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
                let elem_count = l1.shape().elem_count();
                (d_in1 as *const c_void, d_in1_guard, elem_count)
            }
            DType::I32 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
                let elem_count = l1.shape().elem_count();
                (d_in1 as *const c_void, d_in1_guard, elem_count)
            }
            other => {
                return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
                unsafe {
                    ffi::leftshift_u8(
                        d_in1_ptr,
                        d_out_ptr as *mut std::ffi::c_void,
                        u32::try_from(elem_count)?,
                        self.0 as i32,
                    )
                };
                drop(d_out_guard);
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
                unsafe {
                    ffi::leftshift_i32(
                        d_in1_ptr,
                        d_out_ptr as *mut std::ffi::c_void,
                        u32::try_from(elem_count)?,
                        self.0 as i32,
                    )
                };
                drop(d_out_guard);
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-leftshift");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;

        crate::metal_kernels::call_bitwise_leftshift(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            s1.dtype(),
            s1.buffer(),
            l1.start_offset(),
            self.0 as u32,
            out_shape.elem_count(),
            &output,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}
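
/// Elementwise left bit-shift by a fixed number of bits (`x << n`) for integer tensors.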
#[allow(dead_code)]
pub trait LeftshiftOp {
    fn leftshift(&self, n: usize) -> Result<Tensor>;
}

impl LeftshiftOp for Tensor {
    fn leftshift(&self, n: usize) -> Result<Tensor> {
        self.apply_op1_no_bwd(&Leftshift(n))
    }
}

pub enum BitWiseBinaryOpEnum {
    And,
    Or,
    Xor,
}

impl Display for BitWiseBinaryOpEnum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BitWiseBinaryOpEnum::And => write!(f, "And"),
            BitWiseBinaryOpEnum::Or => write!(f, "Or"),
            BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
        }
    }
}

pub enum BitWiseUnaryOpEnum {
    Not,
}

impl Display for BitWiseUnaryOpEnum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BitWiseUnaryOpEnum::Not => write!(f, "Not"),
        }
    }
}
struct BitWise {
    pub op: BitWiseBinaryOpEnum,
}

impl BitWise {
    pub fn new(op: BitWiseBinaryOpEnum) -> Self {
        Self { op }
    }

    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
        &self,
        vs1: &[T],
        vs2: &[T],
    ) -> Vec<T> {
        vs1.into_par_iter()
            .zip_eq(vs2)
            .map(|(v1, v2)| match self.op {
                BitWiseBinaryOpEnum::And => *v1 & *v2,
                BitWiseBinaryOpEnum::Or => *v1 | *v2,
                BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
            })
            .collect()
    }
}

impl CustomOp2 for BitWise {
    fn name(&self) -> &'static str {
        "bitwise"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        match s1 {
            CpuStorage::U8(vs1) => {
                let vs2 = s2.as_slice::<u8>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs2 = s2.as_slice::<u32>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs2 = s2.as_slice::<i64>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs2 = s2.as_slice::<i16>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs2 = s2.as_slice::<i32>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
        }
    }
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &CudaStorage,
        l1: &Layout,
        s2: &CudaStorage,
        l2: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        let dev = s1.device().clone();
        let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
            DType::U8 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
                let elem_count = l1.shape().elem_count();
                (
                    d_in1 as *const std::ffi::c_void,
                    d_in2 as *const std::ffi::c_void,
                    d_in1_guard,
                    d_in2_guard,
                    elem_count,
                )
            }
            DType::U32 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
                let elem_count = l1.shape().elem_count();
                (
                    d_in1 as *const std::ffi::c_void,
                    d_in2 as *const std::ffi::c_void,
                    d_in1_guard,
                    d_in2_guard,
                    elem_count,
                )
            }
            DType::I64 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
                let elem_count = l1.shape().elem_count();
                (
                    d_in1 as *const std::ffi::c_void,
                    d_in2 as *const std::ffi::c_void,
                    d_in1_guard,
                    d_in2_guard,
                    elem_count,
                )
            }
            DType::I32 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
                let elem_count = l1.shape().elem_count();
                (
                    d_in1 as *const std::ffi::c_void,
                    d_in2 as *const std::ffi::c_void,
                    d_in1_guard,
                    d_in2_guard,
                    elem_count,
                )
            }
            DType::I16 => {
                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
                let elem_count = l1.shape().elem_count();
                (
                    d_in1 as *const std::ffi::c_void,
                    d_in2 as *const std::ffi::c_void,
                    d_in1_guard,
                    d_in2_guard,
                    elem_count,
                )
            }
            other => {
                return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
                unsafe {
                    match self.op {
                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                drop(d_out_guard);
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::U32 => {
                let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
                unsafe {
                    match self.op {
                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                drop(d_out_guard);
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I64 => {
                let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
                unsafe {
                    match self.op {
                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                drop(d_out_guard);
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
                unsafe {
                    match self.op {
                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr as *mut c_void,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                drop(d_out_guard);
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
        s2: &candle_core::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-op");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;

        match self.op {
            BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
            BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
            BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
        }

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}
struct BitWiseUnary {
    pub op: BitWiseUnaryOpEnum,
}

impl BitWiseUnary {
    pub fn new(op: BitWiseUnaryOpEnum) -> Self {
        Self { op }
    }

    fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
        vs1.into_par_iter()
            .map(|v1| match self.op {
                BitWiseUnaryOpEnum::Not => !*v1,
            })
            .collect()
    }
}

impl CustomOp1 for BitWiseUnary {
    fn name(&self) -> &'static str {
        "bitwise-unary"
    }

    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        match s1 {
            CpuStorage::U8(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
                };
                let result = self.bitwise(vs1);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
                };
                let result = self.bitwise(vs1);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
                };
                let result = self.bitwise(vs1);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
                };
                let result = self.bitwise(vs1);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
                };
                let result = self.bitwise(vs1);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        todo!()
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-unary-op");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;

        match self.op {
            BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
        }

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}
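
/// Elementwise bitwise operations between two tensors of the same shape, strides and
/// integer dtype, plus a unary bitwise NOT.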
#[allow(dead_code)]
pub trait BitWiseOp {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_not(&self) -> Result<Tensor>;
}

impl BitWiseOp for Tensor {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
    }

    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
    }

    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
    }

    fn bitwise_not(&self) -> Result<Tensor> {
        self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
    }
}

#[allow(unused)]
struct ArgSort {
    axis: usize,
}

#[allow(unused)]
struct Sort {
    axis: usize,
}

impl CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
        candle_core::bail!("ArgSort is not implemented for the CPU backend");
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        candle_core::bail!("ArgSort is not implemented for the CUDA backend");
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("argsort");

        let device = s1.device();
        let out_shape = l1.shape().clone();
        let elem_count = out_shape.elem_count();

        let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;

        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));

        let dims = l1.dims();
        let size_sorted_axis = dims[self.axis];
        let n_rows = l1.shape().elem_count() / size_sorted_axis;
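        // Launch-size heuristic: each thread handles `tn` elements and each threadgroup
        // runs `bn` threads; threadgroups are capped at 256 threads for element types
        // wider than 4 bytes.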
        let tn = 4usize;
        let mut bn = match size_sorted_axis.div_ceil(tn) {
            v if v > 256 => 512,
            v if v > 128 => 256,
            v if v > 64 => 128,
            v if v > 32 => 64,
            _ => 32,
        };
        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
            bn = 256;
        }
        let n_per_block = bn * tn;
        let n_blocks = size_sorted_axis.div_ceil(n_per_block);

        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);

        let sort_args = crate::metal_kernels::SortArgs {
            axis: self.axis,
            shape: l1.dims(),
            strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
            in_contiguous: l1.is_contiguous(),
            in_ty: s1.dtype(),
            out_ty: candle_core::DType::U32,
            src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
            bn,
            tn,
            n_blocks,
        };

        crate::metal_kernels::call_argsort(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            &sort_args,
            &scratch,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            elem_count,
            candle_core::DType::U32,
        );
        Ok((newstorage, out_shape))
    }
}

impl CustomOp1 for Sort {
    fn name(&self) -> &'static str {
        "sort"
    }

    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
        candle_core::bail!("Sort is not implemented for the CPU backend");
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        candle_core::bail!("Sort is not implemented for the CUDA backend");
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("sort");

        let device = s1.device();
        let out_shape = l1.shape().clone();
        let elem_count = out_shape.elem_count();

        let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;

        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));

        let dims = l1.dims();
        let size_sorted_axis = dims[self.axis];
        let n_rows = l1.shape().elem_count() / size_sorted_axis;

        let tn = 4usize;
        let mut bn = match size_sorted_axis.div_ceil(tn) {
            v if v > 256 => 512,
            v if v > 128 => 256,
            v if v > 64 => 128,
            v if v > 32 => 64,
            _ => 32,
        };
        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
            bn = 256;
        }
        let n_per_block = bn * tn;
        let n_blocks = size_sorted_axis.div_ceil(n_per_block);

        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);

        let sort_args = crate::metal_kernels::SortArgs {
            axis: self.axis,
            shape: l1.dims(),
            strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
            in_contiguous: l1.is_contiguous(),
            in_ty: s1.dtype(),
            out_ty: s1.dtype(),
            src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
            bn,
            tn,
            n_blocks,
        };

        crate::metal_kernels::call_sort(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            &sort_args,
            &scratch,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage =
            candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, out_shape))
    }
}
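
/// Ascending sort helpers. On CPU and CUDA these fall back to candle's built-in
/// `sort_last_dim` / `arg_sort_last_dim` (which always sort the last dimension); on
/// Metal they dispatch the custom sort kernels along the requested axis.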
pub trait SortOp {
    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
}

impl SortOp for Tensor {
    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
        if self.device().is_cpu() || self.device().is_cuda() {
            return self.arg_sort_last_dim(true);
        }
        self.apply_op1_no_bwd(&ArgSort {
            axis: axis.to_index(self.shape(), "argsort")?,
        })
    }

    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
        if self.device().is_cpu() || self.device().is_cuda() {
            return Ok(self.sort_last_dim(true)?.0);
        }
        self.apply_op1_no_bwd(&Sort {
            axis: axis.to_index(self.shape(), "sort")?,
        })
    }
}
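
/// Finds the indices of all non-zero elements. The CPU path decomposes each flat index
/// into per-dimension coordinates; the CUDA path first counts the non-zeros and then
/// fills the index buffer with a second kernel.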
struct NonZero;

impl NonZero {
    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
        let n = layout.dims().len();
        let mut result = Vec::new();
        let mut indices = vec![0u32; n];
        for (i, v) in vs.iter().enumerate() {
            if !v.is_zero() {
                let mut idx = i;
                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
                    let d = idx % dim;
                    indices[dim_index] = u32::try_from(d).unwrap();
                    idx /= dim;
                }
                result.extend_from_slice(&indices);
            }
        }
        result
    }
}
#[cfg(feature = "cuda")]
fn count_nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) -> u32 {
    unsafe {
        match dtype {
            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
            _ => unreachable!(),
        }
    }
}
#[allow(clippy::too_many_arguments)]
#[cfg(feature = "cuda")]
fn nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    num_nonzero: u32,
    dims: *const c_void,
    num_dims: u32,
    d_out: *mut c_void,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) {
    unsafe {
        match dtype {
            candle_core::DType::U8 => {
                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::U32 => {
                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I64 => {
                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I32 => {
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I16 => {
                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::BF16 => {
                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F16 => {
                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F32 => {
                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F64 => {
                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            _ => unreachable!(),
        }
    }
}
impl CustomOp1 for NonZero {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "nonzero" });
        }
        let result = match storage {
            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
            _ => unreachable!(),
        };
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
        let result = CpuStorage::U32(result);
        let shape = Shape::from_dims(&[result_len, index_len]);
        Ok((result, shape))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let dev = storage.device().clone();
        let (d_in, _d_in_guard) = match storage.dtype() {
            candle_core::DType::U8 => {
                let slice = storage.as_cuda_slice::<u8>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::U32 => {
                let slice = storage.as_cuda_slice::<u32>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::I32 => {
                let slice = storage.as_cuda_slice::<i32>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::I16 => {
                let slice = storage.as_cuda_slice::<i16>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::I64 => {
                let slice = storage.as_cuda_slice::<i64>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::BF16 => {
                let slice = storage.as_cuda_slice::<half::bf16>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::F16 => {
                let slice = storage.as_cuda_slice::<half::f16>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::F32 => {
                let slice = storage.as_cuda_slice::<f32>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            candle_core::DType::F64 => {
                let slice = storage.as_cuda_slice::<f64>()?;
                let (d_in, d_in_guard) = slice_ptr(slice, 0);
                (d_in as *const std::ffi::c_void, d_in_guard)
            }
            _ => unreachable!(),
        };
        let n = layout.shape().elem_count();

        let num_nonzero = count_nonzero_cuda(
            storage.dtype(),
            d_in,
            u32::try_from(n)?,
            dev.cuda_stream().cu_stream(),
        );
        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
        if num_nonzero != 0 {
            let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
            let dims = layout
                .dims()
                .iter()
                .map(|&x| u32::try_from(x).unwrap())
                .collect::<Vec<u32>>();
            let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
            dev.memcpy_htod(&dims, &mut d_dims)?;
            let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
            nonzero_cuda(
                storage.dtype(),
                d_in,
                u32::try_from(n)?,
                num_nonzero,
                d_dims_ptr as *const c_void,
                u32::try_from(layout.dims().len())?,
                d_out as *mut c_void,
                dev.cuda_stream().cu_stream(),
            );
        }
        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
        Ok((dst, shape))
    }
}
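
/// Returns the indices of all non-zero elements as a `[num_nonzero, rank]` u32 tensor.
/// On Metal the computation is performed on the CPU and the result is copied back.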
pub trait NonZeroOp {
    fn nonzero(&self) -> Result<Tensor>;
}

impl NonZeroOp for Tensor {
    #[cfg(feature = "metal")]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let original_device = self.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op1_no_bwd(&NonZero)?
            .to_device(original_device)
    }

    #[cfg(not(feature = "metal"))]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        self.apply_op1_no_bwd(&NonZero)
    }
}

struct CumSum {
    inclusive: bool,
    reverse: bool,
    axis: usize,
}

impl CustomOp1 for CumSum {
    fn name(&self) -> &'static str {
        "cumsum"
    }

    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
        use std::ops::Add;
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        let dims = l1.dims();
        let axis = self.axis;
        let axis_len = dims[axis];
        let (start, end) = l1
            .contiguous_offsets()
            .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
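
        // `scan_block!` expands to the four scan variants (inclusive/exclusive x
        // forward/reverse) for one concrete element type, walking the contiguous
        // input in blocks of `axis_len` elements.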
        macro_rules! scan_block {
            ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
                let vs: &[$ty] = $vt;
                let input = &vs[start..end];
                let count = input.len() / axis_len;
                let mut result = Vec::<$ty>::with_capacity(input.len());
                if !self.reverse {
                    if self.inclusive {
                        for block in 0..count {
                            let base = block * axis_len;
                            let mut sum = input[base];
                            result.push(sum);
                            for j in 1..axis_len {
                                sum = sum.$add(input[base + j]);
                                result.push(sum);
                            }
                        }
                    } else {
                        let init: $ty = $init;
                        for block in 0..count {
                            let base = block * axis_len;
                            let mut sum = init;
                            for j in 0..axis_len {
                                result.push(sum);
                                sum = sum.$add(input[base + j]);
                            }
                        }
                    }
                } else {
                    if self.inclusive {
                        for block in 0..count {
                            let base = block * axis_len;
                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
                            let mut sum = input[base + axis_len - 1];
                            temp.push(sum);
                            for k in 1..axis_len {
                                let idx = axis_len - 1 - k;
                                sum = sum.$add(input[base + idx]);
                                temp.push(sum);
                            }
                            temp.reverse();
                            result.extend(temp);
                        }
                    } else {
                        let init: $ty = $init;
                        for block in 0..count {
                            let base = block * axis_len;
                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
                            let mut sum = init;
                            for k in 0..axis_len {
                                let idx = axis_len - 1 - k;
                                temp.push(sum);
                                sum = sum.$add(input[base + idx]);
                            }
                            temp.reverse();
                            result.extend(temp);
                        }
                    }
                }
                result
            }};
        }
        match s1 {
            CpuStorage::U8(vs) => {
                let result = scan_block!(vs, u8, wrapping_add, 0u8);
                Ok((CpuStorage::U8(result), l1.shape().clone()))
            }
            CpuStorage::I16(vs) => {
                let result = scan_block!(vs, i16, add, 0i16);
                Ok((CpuStorage::I16(result), l1.shape().clone()))
            }
            CpuStorage::U32(vs) => {
                let result = scan_block!(vs, u32, wrapping_add, 0u32);
                Ok((CpuStorage::U32(result), l1.shape().clone()))
            }
            CpuStorage::I32(vs) => {
                let result = scan_block!(vs, i32, add, 0i32);
                Ok((CpuStorage::I32(result), l1.shape().clone()))
            }
            CpuStorage::I64(vs) => {
                let result = scan_block!(vs, i64, add, 0i64);
                Ok((CpuStorage::I64(result), l1.shape().clone()))
            }
            CpuStorage::F32(vs) => {
                let result = scan_block!(vs, f32, add, 0.0f32);
                Ok((CpuStorage::F32(result), l1.shape().clone()))
            }
            CpuStorage::F64(vs) => {
                let result = scan_block!(vs, f64, add, 0.0f64);
                Ok((CpuStorage::F64(result), l1.shape().clone()))
            }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        todo!()
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        use crate::metal_kernels::ScanType;

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("cumsum");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;

        crate::metal_kernels::call_scan(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            s1.dtype(),
            ScanType::Sum,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            self.axis,
            l1.dims(),
            l1.stride(),
            self.reverse,
            self.inclusive,
            &output,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}
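
/// Cumulative sum along an axis. `fast_cumsum` is the exclusive forward scan;
/// `fast_cumsum_config` exposes the inclusive and reverse variants.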
#[allow(dead_code)]
pub trait CumSumOp {
    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;

    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
        -> Result<Tensor>;
}

impl CumSumOp for Tensor {
    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
        self.fast_cumsum_config(axis, false, false)
    }

    fn fast_cumsum_config<D: Dim>(
        &self,
        axis: D,
        inclusive: bool,
        reverse: bool,
    ) -> Result<Tensor> {
        self.apply_op1_no_bwd(&CumSum {
            inclusive,
            reverse,
            axis: axis.to_index(self.shape(), "cumsum")?,
        })
    }
}
#[cfg(test)]
mod tests {
    #[test]
    fn test_cumsum_exclusive_forward_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
        assert_eq!(b, [0, 1, 3, 6]);
    }

    #[test]
    fn test_cumsum_inclusive_forward_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, false)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [1, 3, 6, 10]);
    }

    #[test]
    fn test_cumsum_exclusive_reverse_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, false, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [9, 7, 4, 0]);
    }

    #[test]
    fn test_cumsum_inclusive_reverse_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [10, 9, 7, 4]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_exclusive_forward_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
        assert_eq!(b, [0, 1, 3, 6]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_inclusive_forward_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, false)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [1, 3, 6, 10]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_exclusive_reverse_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, false, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [9, 7, 4, 0]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_inclusive_reverse_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [10, 9, 7, 4]);
    }
    #[test]
    fn test_nonzero_cpu() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_cuda() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[test]
    fn test_bitwise_and_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_and_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
    }

    #[test]
    fn test_bitwise_or_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_bitwise_xor_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_xor_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }
    #[test]
    fn test_nonzero_and() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let input1 = Tensor::from_vec(
            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input2 = Tensor::from_vec(
            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn nonzero_and_cuda() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let device = Device::new_cuda(0).unwrap();
        let input1 =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input2 =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }
    #[test]
    fn test_bitpack_8bit_cpu() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_8bit_cuda() {
        use crate::HqqBits;
        use candle_core::DType;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_dtype(DType::U8)
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_8bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[test]
    fn test_bitpack_4bit() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_4bit_cuda() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_4bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_sort_and_argsort_vector_metal() {
        use crate::utils::ops::SortOp;
        use candle_core::Tensor;

        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();

        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
        assert_eq!(sorted, [1, 2, 3, 4]);

        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
        assert_eq!(idx, [1, 3, 0, 2]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_sort_and_argsort_matrix_axis1_metal() {
        use crate::utils::ops::SortOp;
        use candle_core::Tensor;

        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();

        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);

        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_sort_and_argsort_vector_2048_metal() {
        use crate::utils::ops::SortOp;
        use candle_core::Tensor;

        const N: usize = 4096;

        let device = candle_core::Device::new_metal(0).expect("Metal device");

        let vals: Vec<i32> = (0..N as i32).rev().collect();
        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();

        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
        let expected: Vec<i32> = (0..N as i32).collect();
        assert_eq!(sorted, expected);

        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
        for (i, &v) in idx.iter().enumerate() {
            assert_eq!(v as usize, N - 1 - i);
        }
    }
}