1use candle_core::{
2 backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3 Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::{
8 fmt::Display,
9 ops::{BitAnd, BitOr, BitXor, Not, Shl},
10};
11
12#[cfg(feature = "cuda")]
13use crate::utils::ffi;
14#[cfg(feature = "cuda")]
15use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
16#[cfg(feature = "cuda")]
17use std::ffi::c_void;
18
19#[cfg(feature = "metal")]
use crate::metal_kernels::SortScratchCache;
#[cfg(feature = "metal")]
use std::sync::OnceLock;
23
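// Lazily-initialized cache of reusable scratch buffers for the Metal sort
// kernels, shared across all sort/argsort invocations.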
24#[cfg(feature = "metal")]
25static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
27struct Leftshift(usize);
28
29impl Leftshift {
30 fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31 let offset = T::from_f64(self.0 as f64);
32 vs.into_par_iter().map(|v| *v << offset).collect()
33 }
34}
35
36impl CustomOp1 for Leftshift {
37 fn name(&self) -> &'static str {
38 "left"
39 }
40
41 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42 match s1 {
43 CpuStorage::U8(vs1) => {
44 let vs1 = match l1.contiguous_offsets() {
45 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
47 };
48 let result = self.leftshift(vs1);
49 let result = CpuStorage::U8(result);
50 Ok((result, l1.shape().clone()))
51 }
52 CpuStorage::I16(vs1) => {
53 let vs1 = match l1.contiguous_offsets() {
54 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
56 };
57 let result = self.leftshift(vs1);
58 let result = CpuStorage::I16(result);
59 Ok((result, l1.shape().clone()))
60 }
61 CpuStorage::U32(vs1) => {
62 let vs1 = match l1.contiguous_offsets() {
63 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
65 };
66 let result = self.leftshift(vs1);
67 let result = CpuStorage::U32(result);
68 Ok((result, l1.shape().clone()))
69 }
70 CpuStorage::I64(vs1) => {
71 let vs1 = match l1.contiguous_offsets() {
72 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
74 };
75 let result = self.leftshift(vs1);
76 let result = CpuStorage::I64(result);
77 Ok((result, l1.shape().clone()))
78 }
79 CpuStorage::I32(vs1) => {
80 let vs1 = match l1.contiguous_offsets() {
81 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
83 };
84 let result = self.leftshift(vs1);
85 let result = CpuStorage::I32(result);
86 Ok((result, l1.shape().clone()))
87 }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift")),
93 }
94 }
95 #[cfg(feature = "cuda")]
96 fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
97 if !l1.is_contiguous() {
98 candle_core::bail!("Input tensor s1 must be contiguous");
99 }
100 let dev = s1.device().clone();
101 let (d_in1_ptr, elem_count) = match s1.dtype() {
102 DType::U8 => {
103 let d_in1_ptr = *s1
104 .as_cuda_slice::<u8>()?
105 .slice(l1.start_offset()..)
106 .device_ptr() as *const c_void;
107 let elem_count = l1.shape().elem_count();
108 (d_in1_ptr, elem_count)
109 }
110 DType::I16 => {
111 return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift"));
112 }
113 DType::U32 => {
114 return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
115 }
116 DType::I64 => {
117 return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
118 }
119 DType::I32 => {
120 let d_in1_ptr = *s1
121 .as_cuda_slice::<i32>()?
122 .slice(l1.start_offset()..)
123 .device_ptr() as *const c_void;
124 let elem_count = l1.shape().elem_count();
125 (d_in1_ptr, elem_count)
126 }
127 DType::BF16 => {
128 return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
129 }
130 DType::F16 => {
131 return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
132 }
133 DType::F32 => {
134 return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
135 }
136 DType::F64 => {
137 return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
138 }
139 DType::F8E4M3 => {
140 return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
141 }
142 };
143 let dst = match s1.dtype() {
144 DType::U8 => {
145 let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
146 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
147 unsafe {
148 ffi::leftshift_u8(
149 d_in1_ptr,
150 d_out_ptr,
151 u32::try_from(elem_count)?,
152 self.0 as i32,
153 )
154 };
155 CudaStorage::wrap_cuda_slice(d_out, dev)
156 }
157 DType::I32 => {
158 let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
159 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
160 unsafe {
161 ffi::leftshift_i32(
162 d_in1_ptr,
163 d_out_ptr,
164 u32::try_from(elem_count)?,
165 self.0 as i32,
166 )
167 };
168 CudaStorage::wrap_cuda_slice(d_out, dev)
169 }
170 _ => unreachable!(),
171 };
172 Ok((dst, l1.shape().clone()))
173 }
174 #[cfg(feature = "metal")]
175 fn metal_fwd(
176 &self,
177 s1: &candle_core::MetalStorage,
178 l1: &Layout,
179 ) -> Result<(candle_core::MetalStorage, Shape)> {
180 if !l1.is_contiguous() {
181 candle_core::bail!("Input tensor s1 must be contiguous");
182 }
183
184 let command_buffer = s1.device().command_buffer()?;
185 command_buffer.set_label("bitwise-leftshift");
186
187 let device = s1.device();
188
189 let out_shape = l1.shape().clone();
190
191 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
192
193 crate::metal_kernels::call_bitwise_leftshift(
194 device.device(),
195 &command_buffer,
196 &crate::metal_kernels::Kernels::new(),
197 s1.dtype(),
198 s1.buffer(),
199 l1.start_offset(),
200 self.0 as u32,
201 out_shape.elem_count(),
202 &output,
203 )
204 .map_err(candle_core::Error::wrap)?;
205
206 let newstorage = candle_core::MetalStorage::new(
207 output,
208 device.clone(),
209 out_shape.elem_count(),
210 s1.dtype(),
211 );
212 Ok((newstorage, out_shape))
213 }
214}
215
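/// Elementwise left shift by a fixed number of bits (integer dtypes only).
///
/// A minimal usage sketch, assuming a CPU tensor of `u8` values:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1u8, 2, 4], &[3], &Device::Cpu)?;
/// let shifted = t.leftshift(2)?; // [4, 8, 16]
/// ```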
216#[allow(dead_code)]
217pub trait LeftshiftOp {
218 fn leftshift(&self, n: usize) -> Result<Tensor>;
219}
220
221impl LeftshiftOp for Tensor {
222 fn leftshift(&self, n: usize) -> Result<Tensor> {
223 self.apply_op1_no_bwd(&Leftshift(n))
224 }
225}
226
227pub enum BitWiseBinaryOpEnum {
228 And,
229 Or,
230 Xor,
231}
232
233impl Display for BitWiseBinaryOpEnum {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 match self {
236 BitWiseBinaryOpEnum::And => write!(f, "And"),
237 BitWiseBinaryOpEnum::Or => write!(f, "Or"),
238 BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
239 }
240 }
241}
242
243pub enum BitWiseUnaryOpEnum {
244 Not,
245}
246
247impl Display for BitWiseUnaryOpEnum {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 match self {
250 BitWiseUnaryOpEnum::Not => write!(f, "Not"),
251 }
252 }
253}
254
255struct BitWise {
256 pub op: BitWiseBinaryOpEnum,
257}
258
259impl BitWise {
260 pub fn new(op: BitWiseBinaryOpEnum) -> Self {
261 Self { op }
262 }
263
264 fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
265 &self,
266 vs1: &[T],
267 vs2: &[T],
268 ) -> Vec<T> {
269 vs1.into_par_iter()
270 .zip_eq(vs2)
271 .map(|(v1, v2)| match self.op {
272 BitWiseBinaryOpEnum::And => *v1 & *v2,
273 BitWiseBinaryOpEnum::Or => *v1 | *v2,
274 BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
275 })
276 .collect()
277 }
278}
279
280impl CustomOp2 for BitWise {
281 fn name(&self) -> &'static str {
282 "bitwise"
283 }
284
285 fn cpu_fwd(
286 &self,
287 s1: &CpuStorage,
288 l1: &Layout,
289 s2: &CpuStorage,
290 l2: &Layout,
291 ) -> Result<(CpuStorage, Shape)> {
292 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
293 return Err(Error::ShapeMismatchBinaryOp {
294 lhs: l1.shape().clone(),
295 rhs: l2.shape().clone(),
296 op: "bitwise-op",
297 });
298 }
299 if s1.dtype() != s2.dtype() {
300 return Err(Error::DTypeMismatchBinaryOp {
301 lhs: s1.dtype(),
302 rhs: s2.dtype(),
303 op: "bitwise-op",
304 });
305 }
306 if !l1.is_contiguous() {
307 candle_core::bail!("Input tensor s1 must be contiguous");
308 }
309 if !l2.is_contiguous() {
310 candle_core::bail!("Input tensor s2 must be contiguous");
311 }
312
313 match s1 {
314 CpuStorage::U8(vs1) => {
315 let vs2 = s2.as_slice::<u8>().unwrap();
316 let vs1 = match l1.contiguous_offsets() {
317 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
319 };
320 let vs2 = match l2.contiguous_offsets() {
321 Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
323 };
324 let result = self.bitwise(vs1, vs2);
325 let result = CpuStorage::U8(result);
326 Ok((result, l1.shape().clone()))
327 }
328 CpuStorage::U32(vs1) => {
329 let vs2 = s2.as_slice::<u32>().unwrap();
330 let vs1 = match l1.contiguous_offsets() {
331 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
333 };
334 let vs2 = match l2.contiguous_offsets() {
335 Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
337 };
338 let result = self.bitwise(vs1, vs2);
339 let result = CpuStorage::U32(result);
340 Ok((result, l1.shape().clone()))
341 }
342 CpuStorage::I64(vs1) => {
343 let vs2 = s2.as_slice::<i64>().unwrap();
344 let vs1 = match l1.contiguous_offsets() {
345 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
347 };
348 let vs2 = match l2.contiguous_offsets() {
349 Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
351 };
352 let result = self.bitwise(vs1, vs2);
353 let result = CpuStorage::I64(result);
354 Ok((result, l1.shape().clone()))
355 }
356 CpuStorage::I16(vs1) => {
357 let vs2 = s2.as_slice::<i16>().unwrap();
358 let vs1 = match l1.contiguous_offsets() {
359 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
361 };
362 let vs2 = match l2.contiguous_offsets() {
363 Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
365 };
366 let result = self.bitwise(vs1, vs2);
367 let result = CpuStorage::I16(result);
368 Ok((result, l1.shape().clone()))
369 }
370 CpuStorage::I32(vs1) => {
371 let vs2 = s2.as_slice::<i32>().unwrap();
372 let vs1 = match l1.contiguous_offsets() {
373 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
375 };
376 let vs2 = match l2.contiguous_offsets() {
377 Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
379 };
380 let result = self.bitwise(vs1, vs2);
381 let result = CpuStorage::I32(result);
382 Ok((result, l1.shape().clone()))
383 }
384 CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
385 CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
386 CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
387 CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
388 CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
389 }
390 }
391
392 #[cfg(feature = "cuda")]
393 fn cuda_fwd(
394 &self,
395 s1: &CudaStorage,
396 l1: &Layout,
397 s2: &CudaStorage,
398 l2: &Layout,
399 ) -> Result<(CudaStorage, Shape)> {
400 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
401 return Err(Error::ShapeMismatchBinaryOp {
402 lhs: l1.shape().clone(),
403 rhs: l2.shape().clone(),
404 op: "bitwise-op",
405 });
406 }
407 if s1.dtype() != s2.dtype() {
408 return Err(Error::DTypeMismatchBinaryOp {
409 lhs: s1.dtype(),
410 rhs: s2.dtype(),
411 op: "bitwise-op",
412 });
413 }
414 if !l1.is_contiguous() {
415 candle_core::bail!("Input tensor s1 must be contiguous");
416 }
417 if !l2.is_contiguous() {
418 candle_core::bail!("Input tensor s2 must be contiguous");
419 }
420
421 let dev = s1.device().clone();
422 let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
423 DType::U8 => {
424 let d_in1_ptr = *s1
425 .as_cuda_slice::<u8>()?
426 .slice(l1.start_offset()..)
427 .device_ptr() as *const c_void;
428 let d_in2_ptr = *s2
429 .as_cuda_slice::<u8>()?
430 .slice(l2.start_offset()..)
431 .device_ptr() as *const c_void;
432 let elem_count = l1.shape().elem_count();
433 (d_in1_ptr, d_in2_ptr, elem_count)
434 }
435 DType::U32 => {
436 let d_in1_ptr = *s1
437 .as_cuda_slice::<u32>()?
438 .slice(l1.start_offset()..)
439 .device_ptr() as *const c_void;
440 let d_in2_ptr = *s2
441 .as_cuda_slice::<u32>()?
442 .slice(l2.start_offset()..)
443 .device_ptr() as *const c_void;
444 let elem_count = l1.shape().elem_count();
445 (d_in1_ptr, d_in2_ptr, elem_count)
446 }
447 DType::I64 => {
448 let d_in1_ptr = *s1
449 .as_cuda_slice::<i64>()?
450 .slice(l1.start_offset()..)
451 .device_ptr() as *const c_void;
452 let d_in2_ptr = *s2
453 .as_cuda_slice::<i64>()?
454 .slice(l2.start_offset()..)
455 .device_ptr() as *const c_void;
456 let elem_count = l1.shape().elem_count();
457 (d_in1_ptr, d_in2_ptr, elem_count)
458 }
459 DType::I32 => {
460 let d_in1_ptr = *s1
461 .as_cuda_slice::<i32>()?
462 .slice(l1.start_offset()..)
463 .device_ptr() as *const c_void;
464 let d_in2_ptr = *s2
465 .as_cuda_slice::<i32>()?
466 .slice(l2.start_offset()..)
467 .device_ptr() as *const c_void;
468 let elem_count = l1.shape().elem_count();
469 (d_in1_ptr, d_in2_ptr, elem_count)
470 }
471 DType::I16 => {
472 let d_in1_ptr = *s1
473 .as_cuda_slice::<i16>()?
474 .slice(l1.start_offset()..)
475 .device_ptr() as *const c_void;
476 let d_in2_ptr = *s2
477 .as_cuda_slice::<i16>()?
478 .slice(l2.start_offset()..)
479 .device_ptr() as *const c_void;
480 let elem_count = l1.shape().elem_count();
481 (d_in1_ptr, d_in2_ptr, elem_count)
482 }
483 DType::BF16 => {
484 return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise"));
485 }
486 DType::F16 => {
487 return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise"));
488 }
489 DType::F32 => {
490 return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise"));
491 }
492 DType::F64 => {
493 return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
494 }
495 DType::F8E4M3 => {
496 return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
497 }
498 };
499 let dst = match s1.dtype() {
500 DType::U8 => {
501 let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
502 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
503 unsafe {
504 match self.op {
505 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
506 d_in1_ptr,
507 d_in2_ptr,
508 d_out_ptr,
509 u32::try_from(elem_count)?,
510 ),
511 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
512 d_in1_ptr,
513 d_in2_ptr,
514 d_out_ptr,
515 u32::try_from(elem_count)?,
516 ),
517 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
518 d_in1_ptr,
519 d_in2_ptr,
520 d_out_ptr,
521 u32::try_from(elem_count)?,
522 ),
523 }
524 };
525 CudaStorage::wrap_cuda_slice(d_out, dev)
526 }
527 DType::U32 => {
528 let d_out = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
529 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
530 unsafe {
531 match self.op {
532 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
533 d_in1_ptr,
534 d_in2_ptr,
535 d_out_ptr,
536 u32::try_from(elem_count)?,
537 ),
538 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
539 d_in1_ptr,
540 d_in2_ptr,
541 d_out_ptr,
542 u32::try_from(elem_count)?,
543 ),
544 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
545 d_in1_ptr,
546 d_in2_ptr,
547 d_out_ptr,
548 u32::try_from(elem_count)?,
549 ),
550 }
551 };
552 CudaStorage::wrap_cuda_slice(d_out, dev)
553 }
554 DType::I64 => {
555 let d_out = unsafe { dev.alloc::<i64>(elem_count) }.w()?;
556 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
557 unsafe {
558 match self.op {
559 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
560 d_in1_ptr,
561 d_in2_ptr,
562 d_out_ptr,
563 u32::try_from(elem_count)?,
564 ),
565 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
566 d_in1_ptr,
567 d_in2_ptr,
568 d_out_ptr,
569 u32::try_from(elem_count)?,
570 ),
571 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
572 d_in1_ptr,
573 d_in2_ptr,
574 d_out_ptr,
575 u32::try_from(elem_count)?,
576 ),
577 }
578 };
579 CudaStorage::wrap_cuda_slice(d_out, dev)
580 }
581 DType::I32 => {
582 let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
583 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
584 unsafe {
585 match self.op {
586 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
587 d_in1_ptr,
588 d_in2_ptr,
589 d_out_ptr,
590 u32::try_from(elem_count)?,
591 ),
592 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
593 d_in1_ptr,
594 d_in2_ptr,
595 d_out_ptr,
596 u32::try_from(elem_count)?,
597 ),
598 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
599 d_in1_ptr,
600 d_in2_ptr,
601 d_out_ptr,
602 u32::try_from(elem_count)?,
603 ),
604 }
605 };
606 CudaStorage::wrap_cuda_slice(d_out, dev)
607 }
608 _ => unreachable!(),
609 };
610 Ok((dst, l1.shape().clone()))
611 }
612
613 #[cfg(feature = "metal")]
614 fn metal_fwd(
615 &self,
616 s1: &candle_core::MetalStorage,
617 l1: &Layout,
618 s2: &candle_core::MetalStorage,
619 l2: &Layout,
620 ) -> Result<(candle_core::MetalStorage, Shape)> {
621 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
622 return Err(Error::ShapeMismatchBinaryOp {
623 lhs: l1.shape().clone(),
624 rhs: l2.shape().clone(),
625 op: "bitwise-op",
626 });
627 }
628 if s1.dtype() != s2.dtype() {
629 return Err(Error::DTypeMismatchBinaryOp {
630 lhs: s1.dtype(),
631 rhs: s2.dtype(),
632 op: "bitwise-op",
633 });
634 }
635 if !l1.is_contiguous() {
636 candle_core::bail!("Input tensor s1 must be contiguous");
637 }
638 if !l2.is_contiguous() {
639 candle_core::bail!("Input tensor s2 must be contiguous");
640 }
641
642 let command_buffer = s1.device().command_buffer()?;
643 command_buffer.set_label("bitwise-op");
644
645 let device = s1.device();
646
647 let out_shape = l1.shape().clone();
648
649 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
650
651 match self.op {
652 BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
653 device.device(),
654 &command_buffer,
655 &crate::metal_kernels::Kernels::new(),
656 s1.dtype(),
657 s1.buffer(),
658 s2.buffer(),
659 l1.start_offset() * s1.dtype().size_in_bytes(),
660 l2.start_offset() * s2.dtype().size_in_bytes(),
661 out_shape.elem_count(),
662 &output,
663 )
664 .map_err(candle_core::Error::wrap)?,
665 BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
666 device.device(),
667 &command_buffer,
668 &crate::metal_kernels::Kernels::new(),
669 s1.dtype(),
670 s1.buffer(),
671 s2.buffer(),
672 l1.start_offset() * s1.dtype().size_in_bytes(),
673 l2.start_offset() * s2.dtype().size_in_bytes(),
674 out_shape.elem_count(),
675 &output,
676 )
677 .map_err(candle_core::Error::wrap)?,
678 BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
679 device.device(),
680 &command_buffer,
681 &crate::metal_kernels::Kernels::new(),
682 s1.dtype(),
683 s1.buffer(),
684 s2.buffer(),
685 l1.start_offset() * s1.dtype().size_in_bytes(),
686 l2.start_offset() * s2.dtype().size_in_bytes(),
687 out_shape.elem_count(),
688 &output,
689 )
690 .map_err(candle_core::Error::wrap)?,
691 }
692
693 let newstorage = candle_core::MetalStorage::new(
694 output,
695 device.clone(),
696 out_shape.elem_count(),
697 s1.dtype(),
698 );
699 Ok((newstorage, out_shape))
700 }
701}
702
703struct BitWiseUnary {
704 pub op: BitWiseUnaryOpEnum,
705}
706
707impl BitWiseUnary {
708 pub fn new(op: BitWiseUnaryOpEnum) -> Self {
709 Self { op }
710 }
711
712 fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
713 vs1.into_par_iter()
714 .map(|v1| match self.op {
715 BitWiseUnaryOpEnum::Not => !*v1,
716 })
717 .collect()
718 }
719}
720
721impl CustomOp1 for BitWiseUnary {
722 fn name(&self) -> &'static str {
723 "bitwise-unary"
724 }
725
726 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
727 if !l1.is_contiguous() {
728 candle_core::bail!("Input tensor s1 must be contiguous");
729 }
730
731 match s1 {
732 CpuStorage::U8(vs1) => {
733 let vs1 = match l1.contiguous_offsets() {
734 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
736 };
737 let result = self.bitwise(vs1);
738 let result = CpuStorage::U8(result);
739 Ok((result, l1.shape().clone()))
740 }
741 CpuStorage::U32(vs1) => {
742 let vs1 = match l1.contiguous_offsets() {
743 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
745 };
746 let result = self.bitwise(vs1);
747 let result = CpuStorage::U32(result);
748 Ok((result, l1.shape().clone()))
749 }
750 CpuStorage::I64(vs1) => {
751 let vs1 = match l1.contiguous_offsets() {
752 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
754 };
755 let result = self.bitwise(vs1);
756 let result = CpuStorage::I64(result);
757 Ok((result, l1.shape().clone()))
758 }
759 CpuStorage::I16(vs1) => {
760 let vs1 = match l1.contiguous_offsets() {
761 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
763 };
764 let result = self.bitwise(vs1);
765 let result = CpuStorage::I16(result);
766 Ok((result, l1.shape().clone()))
767 }
768 CpuStorage::I32(vs1) => {
769 let vs1 = match l1.contiguous_offsets() {
770 Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
772 };
773 let result = self.bitwise(vs1);
774 let result = CpuStorage::I32(result);
775 Ok((result, l1.shape().clone()))
776 }
777 CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
778 CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
779 CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
780 CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
781 CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
782 }
783 }
784
785 #[cfg(feature = "cuda")]
786 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
787 todo!()
788 }
789
790 #[cfg(feature = "metal")]
791 fn metal_fwd(
792 &self,
793 s1: &candle_core::MetalStorage,
794 l1: &Layout,
795 ) -> Result<(candle_core::MetalStorage, Shape)> {
796 if !l1.is_contiguous() {
797 candle_core::bail!("Input tensor s1 must be contiguous");
798 }
799
800 let command_buffer = s1.device().command_buffer()?;
801 command_buffer.set_label("bitwise-unary-op");
802
803 let device = s1.device();
804
805 let out_shape = l1.shape().clone();
806
807 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
808
809 match self.op {
810 BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
811 device.device(),
812 &command_buffer,
813 &crate::metal_kernels::Kernels::new(),
814 s1.dtype(),
815 s1.buffer(),
816 l1.start_offset() * s1.dtype().size_in_bytes(),
817 out_shape.elem_count(),
818 &output,
819 )
820 .map_err(candle_core::Error::wrap)?,
821 }
822
823 let newstorage = candle_core::MetalStorage::new(
824 output,
825 device.clone(),
826 out_shape.elem_count(),
827 s1.dtype(),
828 );
829 Ok((newstorage, out_shape))
830 }
831}
832
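/// Elementwise bitwise operations between integer tensors of identical shape
/// and dtype, plus bitwise negation.
///
/// A minimal usage sketch, assuming two CPU tensors of `u8` values:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let a = Tensor::from_vec(vec![0b1100u8, 0b1010], &[2], &Device::Cpu)?;
/// let b = Tensor::from_vec(vec![0b1010u8, 0b0110], &[2], &Device::Cpu)?;
/// let and = a.bitwise_and(&b)?; // [0b1000, 0b0010]
/// let xor = a.bitwise_xor(&b)?; // [0b0110, 0b1100]
/// ```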
833#[allow(dead_code)]
834pub trait BitWiseOp {
835 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
836 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
837 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
838 fn bitwise_not(&self) -> Result<Tensor>;
839}
840
841impl BitWiseOp for Tensor {
842 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
843 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
844 }
845
846 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
847 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
848 }
849
850 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
851 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
852 }
853
854 fn bitwise_not(&self) -> Result<Tensor> {
855 self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
856 }
857}
858
859#[allow(unused)]
862struct ArgSort {
864 axis: usize,
865}
866
867#[allow(unused)]
868struct Sort {
870 axis: usize,
871}
872
873impl CustomOp1 for ArgSort {
874 fn name(&self) -> &'static str {
875 "argsort"
876 }
877
878 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
880 candle_core::bail!("ArgSort is not implemented for the CPU backend");
881 }
882
883 #[cfg(feature = "cuda")]
885 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
886 candle_core::bail!("ArgSort is not implemented for the CUDA backend");
887 }
888
889 #[cfg(feature = "metal")]
891 fn metal_fwd(
892 &self,
893 s1: &candle_core::MetalStorage,
894 l1: &Layout,
895 ) -> Result<(candle_core::MetalStorage, Shape)> {
896 if !l1.is_contiguous() {
898 candle_core::bail!("Input tensor s1 must be contiguous");
899 }
900
901 let command_buffer = s1.device().command_buffer()?;
903 command_buffer.set_label("argsort");
904
905 let device = s1.device();
906 let out_shape = l1.shape().clone();
907 let elem_count = out_shape.elem_count();
908
909 let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
911
912 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
916
917 let dims = l1.dims();
918 let size_sorted_axis = dims[self.axis];
919 let n_rows = l1.shape().elem_count() / size_sorted_axis;
920
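        // Block-size heuristic: each thread handles `tn` elements, and the
        // threadgroup width `bn` is picked so a single block covers the sorted
        // axis where possible, capped at 256 threads for dtypes wider than
        // 4 bytes to keep threadgroup memory usage bounded.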
921 let tn = 4usize;
923 let mut bn = match size_sorted_axis.div_ceil(tn) {
924 v if v > 256 => 512,
925 v if v > 128 => 256,
926 v if v > 64 => 128,
927 v if v > 32 => 64,
928 _ => 32,
929 };
930 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
931 bn = 256;
932 }
933 let n_per_block = bn * tn;
934 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
935
936 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
938
939 let sort_args = crate::metal_kernels::SortArgs {
943 axis: self.axis,
944 shape: l1.dims(),
945 strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
948 in_contiguous: l1.is_contiguous(),
949 in_ty: s1.dtype(),
950 out_ty: candle_core::DType::U32,
951 src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
954 bn,
955 tn,
956 n_blocks,
957 };
958
959 crate::metal_kernels::call_argsort(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
964 &sort_args,
965 &scratch,
966 )
967 .map_err(candle_core::Error::wrap)?;
968
969 let newstorage = candle_core::MetalStorage::new(
971 output,
972 device.clone(),
973 elem_count,
974 candle_core::DType::U32,
975 );
976 Ok((newstorage, out_shape))
977 }
978}
979
980impl CustomOp1 for Sort {
981 fn name(&self) -> &'static str {
982 "sort"
983 }
984
985 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
987 candle_core::bail!("Sort is not implemented for the CPU backend");
988 }
989
990 #[cfg(feature = "cuda")]
992 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
993 candle_core::bail!("Sort is not implemented for the CUDA backend");
994 }
995
996 #[cfg(feature = "metal")]
998 fn metal_fwd(
999 &self,
1000 s1: &candle_core::MetalStorage,
1001 l1: &Layout,
1002 ) -> Result<(candle_core::MetalStorage, Shape)> {
1003 if !l1.is_contiguous() {
1005 candle_core::bail!("Input tensor s1 must be contiguous");
1006 }
1007
1008 let command_buffer = s1.device().command_buffer()?;
1010 command_buffer.set_label("sort");
1011
1012 let device = s1.device();
1013 let out_shape = l1.shape().clone();
1014 let elem_count = out_shape.elem_count();
1015
1016 let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
1018
1019 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
1023
1024 let dims = l1.dims();
1025 let size_sorted_axis = dims[self.axis];
1026 let n_rows = l1.shape().elem_count() / size_sorted_axis;
1027
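        // Same block-size heuristic as in ArgSort::metal_fwd above.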
1028 let tn = 4usize;
1030 let mut bn = match size_sorted_axis.div_ceil(tn) {
1031 v if v > 256 => 512,
1032 v if v > 128 => 256,
1033 v if v > 64 => 128,
1034 v if v > 32 => 64,
1035 _ => 32,
1036 };
1037 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
1038 bn = 256;
1039 }
1040 let n_per_block = bn * tn;
1041 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
1042
1043 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1045
1046 let sort_args = crate::metal_kernels::SortArgs {
1050 axis: self.axis,
1051 shape: l1.dims(),
1052 strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
1055 in_contiguous: l1.is_contiguous(),
1056 in_ty: s1.dtype(),
1057 out_ty: s1.dtype(),
1058 src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
1061 bn,
1062 tn,
1063 n_blocks,
1064 };
1065
1066 crate::metal_kernels::call_sort(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
1071 &sort_args,
1072 &scratch,
1073 )
1074 .map_err(candle_core::Error::wrap)?;
1075
1076 let newstorage =
1078 candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1079 Ok((newstorage, out_shape))
1080 }
1081}
1082
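/// Ascending sort / argsort along one axis.
///
/// On Metal this dispatches the custom kernels above; on CPU and CUDA it
/// falls back to candle's built-in `sort_last_dim` / `arg_sort_last_dim`
/// (so the requested axis is ignored there and the last dimension is used).
///
/// A minimal usage sketch, assuming a Metal device is available:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let device = Device::new_metal(0)?;
/// let t = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device)?;
/// let sorted = t.fast_sort_asc(0)?;     // [1, 2, 3, 4]
/// let indices = t.fast_argsort_asc(0)?; // [1, 3, 0, 2]
/// ```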
1083pub trait SortOp {
1085 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1087 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1089}
1090
1091impl SortOp for Tensor {
1092 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1093 if self.device().is_cpu() || self.device().is_cuda() {
1094 return self.arg_sort_last_dim(true);
1095 }
1096 self.apply_op1_no_bwd(&ArgSort {
1097 axis: axis.to_index(self.shape(), "argsort")?,
1098 })
1099 }
1100
1101 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1102 if self.device().is_cpu() || self.device().is_cuda() {
1103 return Ok(self.sort_last_dim(true)?.0);
1104 }
1105 self.apply_op1_no_bwd(&Sort {
1106 axis: axis.to_index(self.shape(), "sort")?,
1107 })
1108 }
1109}
1110
1111struct NonZero;
1112
1113impl NonZero {
1114 fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1116 let n = layout.dims().len();
1117 let mut result = Vec::new();
1118 let mut indices = vec![0u32; n];
1119 for (i, v) in vs.iter().enumerate() {
1120 if !v.is_zero() {
1121 let mut idx = i;
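                // Unravel the flat row-major index into per-dimension
                // coordinates, filling `indices` from the last axis backwards.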
1122 for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1123 let d = idx % dim;
1124 indices[dim_index] = u32::try_from(d).unwrap();
1125 idx /= dim;
1126 }
1127 result.extend_from_slice(&indices);
1128 }
1129 }
1130 result
1131 }
1132}
1133
1134#[cfg(feature = "cuda")]
1135fn count_nonzero_cuda(
1136 dtype: candle_core::DType,
1137 d_in: *const c_void,
1138 n: u32,
1139 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1140) -> u32 {
1141 unsafe {
1142 match dtype {
1143 candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1144 candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1145 candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1146 candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1147 candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1148 candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1149 candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1150 candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1151 candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1152 candle_core::DType::F8E4M3 => todo!(),
1153 }
1154 }
1155}
1156
1157#[allow(clippy::too_many_arguments)]
1158#[cfg(feature = "cuda")]
1159fn nonzero_cuda(
1160 dtype: candle_core::DType,
1161 d_in: *const c_void,
1162 n: u32,
1163 num_nonzero: u32,
1164 dims: *const c_void,
1165 num_dims: u32,
1166 d_out: *mut c_void,
1167 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1168) {
1169 unsafe {
1170 match dtype {
1171 candle_core::DType::U8 => {
1172 ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1173 }
1174 candle_core::DType::U32 => {
1175 ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1176 }
1177 candle_core::DType::I64 => {
1178 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1179 }
            candle_core::DType::I32 => {
                // Assumes the ffi module exports an i32 kernel alongside `count_nonzero_i32`;
                // the i64 kernel would misread 4-byte elements.
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
1183 candle_core::DType::I16 => {
1184 ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1185 }
1186 candle_core::DType::BF16 => {
1187 ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1188 }
1189 candle_core::DType::F16 => {
1190 ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1191 }
1192 candle_core::DType::F32 => {
1193 ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1194 }
1195 candle_core::DType::F64 => {
1196 ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1197 }
1198 candle_core::DType::F8E4M3 => todo!(),
1199 }
1200 }
1201}
1202
1203impl CustomOp1 for NonZero {
1204 fn name(&self) -> &'static str {
1205 "nonzero"
1206 }
1207
1208 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1209 if !layout.is_contiguous() {
1210 return Err(Error::RequiresContiguous { op: "nonzero" });
1211 }
1212 let result = match storage {
1213 candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1214 candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1215 candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1216 candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1217 candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1218 candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1219 candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1220 candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1221 candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1222 candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
1223 };
1224 let index_len = layout.dims().len();
1225 let result_len = result.len() / index_len;
1226 let result = CpuStorage::U32(result);
1227 let shape = Shape::from_dims(&[result_len, index_len]);
1228 Ok((result, shape))
1229 }
1230 #[cfg(feature = "cuda")]
1231 fn cuda_fwd(
1232 &self,
1233 storage: &candle_core::CudaStorage,
1234 layout: &Layout,
1235 ) -> Result<(candle_core::CudaStorage, Shape)> {
1236 if !layout.is_contiguous() {
1237 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1238 }
1239 let dev = storage.device().clone();
1240 let d_in = match storage.dtype() {
1241 candle_core::DType::U8 => *storage.as_cuda_slice::<u8>()?.device_ptr(),
1242 candle_core::DType::U32 => *storage.as_cuda_slice::<u32>()?.device_ptr(),
1243 candle_core::DType::I32 => *storage.as_cuda_slice::<i32>()?.device_ptr(),
1244 candle_core::DType::I16 => *storage.as_cuda_slice::<i16>()?.device_ptr(),
1245 candle_core::DType::I64 => *storage.as_cuda_slice::<i64>()?.device_ptr(),
1246 candle_core::DType::BF16 => *storage.as_cuda_slice::<half::bf16>()?.device_ptr(),
1247 candle_core::DType::F16 => *storage.as_cuda_slice::<half::f16>()?.device_ptr(),
1248 candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
1249 candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
1250 candle_core::DType::F8E4M3 => todo!(),
1251 } as *const c_void;
1252 let n = layout.shape().elem_count();
1253
1254 let num_nonzero =
1255 count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream());
1256 let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1257 .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1258 if num_nonzero != 0 {
1259 let d_out_ptr = *d_out.device_ptr() as *mut c_void;
1260 let dims = layout
1261 .dims()
1262 .iter()
1263 .map(|&x| u32::try_from(x).unwrap())
1264 .collect::<Vec<u32>>();
1265 let d_dims = dev
1266 .htod_copy(dims)
1267 .map_err(|_| Error::Msg("Failed to copy dims to device".to_string()))?;
1268 let d_dims_ptr = *d_dims.device_ptr() as *const c_void;
1269 nonzero_cuda(
1270 storage.dtype(),
1271 d_in,
1272 u32::try_from(n)?,
1273 num_nonzero,
1274 d_dims_ptr,
1275 u32::try_from(layout.dims().len())?,
1276 d_out_ptr,
1277 *dev.cu_stream(),
1278 );
1279 }
1280 let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1281 let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1282 Ok((dst, shape))
1283 }
1284}
1285
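/// Indices of the non-zero elements, returned as a `[num_nonzero, ndims]`
/// tensor of `u32` coordinates (one row per non-zero element). On Metal the
/// computation is performed on the CPU and the result copied back.
///
/// A minimal usage sketch on the CPU:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![0f32, 5.0, 0.0, 7.0], &[2, 2], &Device::Cpu)?;
/// let idx = t.nonzero()?; // [[0, 1], [1, 1]]
/// ```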
1286pub trait NonZeroOp {
1287 fn nonzero(&self) -> Result<Tensor>;
1288}
1289
1290impl NonZeroOp for Tensor {
1291 #[cfg(feature = "metal")]
1292 fn nonzero(&self) -> Result<Tensor> {
1293 if !self.is_contiguous() {
1294 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1295 }
1296 let original_device = self.device();
1297 self.to_device(&candle_core::Device::Cpu)?
1298 .apply_op1_no_bwd(&NonZero)?
1299 .to_device(original_device)
1300 }
1301
1302 #[cfg(not(feature = "metal"))]
1303 fn nonzero(&self) -> Result<Tensor> {
1304 if !self.is_contiguous() {
1305 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1306 }
1307 self.apply_op1_no_bwd(&NonZero)
1308 }
1309}
1310
1311struct CumSum {
1312 inclusive: bool,
1313 reverse: bool,
1314 axis: usize,
1315}
1316
1317impl CustomOp1 for CumSum {
1318 fn name(&self) -> &'static str {
1319 "cumsum"
1320 }
1321
1322 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1323 use std::ops::Add;
1324 if !l1.is_contiguous() {
1325 candle_core::bail!("Input tensor s1 must be contiguous");
1326 }
1327 let dims = l1.dims();
1328 let axis = self.axis;
1329 let axis_len = dims[axis];
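        // The scan below walks the contiguous buffer as back-to-back runs of
        // `axis_len` elements, i.e. it assumes the cumsum axis is the
        // innermost (contiguous) dimension.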
1330 let (start, end) = l1
1331 .contiguous_offsets()
1332 .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1333
1334 macro_rules! scan_block {
1336 ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1337 let vs: &[$ty] = $vt;
1338 let input = &vs[start..end];
1339 let count = input.len() / axis_len;
1340 let mut result = Vec::<$ty>::with_capacity(input.len());
1341 if !self.reverse {
1342 if self.inclusive {
1343 for block in 0..count {
1344 let base = block * axis_len;
1345 let mut sum = input[base];
1346 result.push(sum);
1347 for j in 1..axis_len {
1348 sum = sum.$add(input[base + j]);
1349 result.push(sum);
1350 }
1351 }
1352 } else {
1353 let init: $ty = $init;
1354 for block in 0..count {
1355 let base = block * axis_len;
1356 let mut sum = init;
1357 for j in 0..axis_len {
1358 result.push(sum);
1359 sum = sum.$add(input[base + j]);
1360 }
1361 }
1362 }
1363 } else {
1364 if self.inclusive {
1365 for block in 0..count {
1366 let base = block * axis_len;
1367 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1368 let mut sum = input[base + axis_len - 1];
1369 temp.push(sum);
1370 for k in 1..axis_len {
1371 let idx = axis_len - 1 - k;
1372 sum = sum.$add(input[base + idx]);
1373 temp.push(sum);
1374 }
1375 temp.reverse();
1376 result.extend(temp);
1377 }
1378 } else {
1379 let init: $ty = $init;
1380 for block in 0..count {
1381 let base = block * axis_len;
1382 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1383 let mut sum = init;
1384 for k in 0..axis_len {
1385 let idx = axis_len - 1 - k;
1386 temp.push(sum);
1387 sum = sum.$add(input[base + idx]);
1388 }
1389 temp.reverse();
1390 result.extend(temp);
1391 }
1392 }
1393 }
1394 result
1395 }};
1396 }
1397 match s1 {
1398 CpuStorage::U8(vs) => {
1399 let result = scan_block!(vs, u8, wrapping_add, 0u8);
1400 Ok((CpuStorage::U8(result), l1.shape().clone()))
1401 }
1402 CpuStorage::I16(vs) => {
1403 let result = scan_block!(vs, i16, add, 0i16);
1404 Ok((CpuStorage::I16(result), l1.shape().clone()))
1405 }
1406 CpuStorage::U32(vs) => {
1407 let result = scan_block!(vs, u32, wrapping_add, 0u32);
1408 Ok((CpuStorage::U32(result), l1.shape().clone()))
1409 }
1410 CpuStorage::I32(vs) => {
1411 let result = scan_block!(vs, i32, add, 0i32);
1412 Ok((CpuStorage::I32(result), l1.shape().clone()))
1413 }
1414 CpuStorage::I64(vs) => {
1415 let result = scan_block!(vs, i64, add, 0i64);
1416 Ok((CpuStorage::I64(result), l1.shape().clone()))
1417 }
1418 CpuStorage::F32(vs) => {
1419 let result = scan_block!(vs, f32, add, 0.0f32);
1420 Ok((CpuStorage::F32(result), l1.shape().clone()))
1421 }
1422 CpuStorage::F64(vs) => {
1423 let result = scan_block!(vs, f64, add, 0.0f64);
1424 Ok((CpuStorage::F64(result), l1.shape().clone()))
1425 }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1427 }
1428 }
1429
1430 #[cfg(feature = "cuda")]
1431 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
1432 todo!()
1433 }
1434
1435 #[cfg(feature = "metal")]
1436 fn metal_fwd(
1437 &self,
1438 s1: &candle_core::MetalStorage,
1439 l1: &Layout,
1440 ) -> Result<(candle_core::MetalStorage, Shape)> {
1441 use crate::metal_kernels::ScanType;
1442
1443 let command_buffer = s1.device().command_buffer()?;
1444 command_buffer.set_label("cumsum");
1445
1446 let device = s1.device();
1447
1448 let out_shape = l1.shape().clone();
1449
1450 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1451
1452 crate::metal_kernels::call_scan(
1453 device.device(),
1454 &command_buffer,
1455 &crate::metal_kernels::Kernels::new(),
1456 s1.dtype(),
1457 ScanType::Sum,
1458 s1.buffer(),
1459 l1.start_offset() * s1.dtype().size_in_bytes(),
1460 self.axis,
1461 l1.dims(),
1462 l1.stride(),
1463 self.reverse,
1464 self.inclusive,
1465 &output,
1466 )
1467 .map_err(candle_core::Error::wrap)?;
1468
1469 let newstorage = candle_core::MetalStorage::new(
1470 output,
1471 device.clone(),
1472 out_shape.elem_count(),
1473 s1.dtype(),
1474 );
1475 Ok((newstorage, out_shape))
1476 }
1477}
1478
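/// Cumulative sum along one axis.
///
/// `fast_cumsum` computes an exclusive, forward scan; `fast_cumsum_config`
/// exposes the inclusive and reverse variants.
///
/// A minimal usage sketch on the CPU:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu)?;
/// let exclusive = t.fast_cumsum(0)?;                      // [0, 1, 3, 6]
/// let inclusive = t.fast_cumsum_config(0, true, false)?;  // [1, 3, 6, 10]
/// ```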
1479#[allow(dead_code)]
1480pub trait CumSumOp {
1481 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1483
1484 fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1485 -> Result<Tensor>;
1486}
1487
1488impl CumSumOp for Tensor {
1489 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1490 self.fast_cumsum_config(axis, false, false)
1491 }
1492
1493 fn fast_cumsum_config<D: Dim>(
1494 &self,
1495 axis: D,
1496 inclusive: bool,
1497 reverse: bool,
1498 ) -> Result<Tensor> {
1499 self.apply_op1_no_bwd(&CumSum {
1500 inclusive,
1501 reverse,
1502 axis: axis.to_index(self.shape(), "cumsum")?,
1503 })
1504 }
1505}
1506
#[cfg(test)]
mod tests {
1508 #[test]
1509 fn test_cumsum_exclusive_forward_cpu() {
1510 use crate::utils::ops::CumSumOp;
1511 use candle_core::Tensor;
1512 let device = candle_core::Device::Cpu;
1513 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1514 let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1515 assert_eq!(b, [0, 1, 3, 6]);
1516 }
1517
1518 #[test]
1519 fn test_cumsum_inclusive_forward_cpu() {
1520 use crate::utils::ops::CumSumOp;
1521 use candle_core::Tensor;
1522 let device = candle_core::Device::Cpu;
1523 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1524 let b = a
1525 .fast_cumsum_config(0, true, false)
1526 .unwrap()
1527 .to_vec1::<i64>()
1528 .unwrap();
1529 assert_eq!(b, [1, 3, 6, 10]);
1530 }
1531
1532 #[test]
1533 fn test_cumsum_exclusive_reverse_cpu() {
1534 use crate::utils::ops::CumSumOp;
1535 use candle_core::Tensor;
1536 let device = candle_core::Device::Cpu;
1537 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1538 let b = a
1539 .fast_cumsum_config(0, false, true)
1540 .unwrap()
1541 .to_vec1::<i64>()
1542 .unwrap();
1543 assert_eq!(b, [9, 7, 4, 0]);
1544 }
1545
1546 #[test]
1547 fn test_cumsum_inclusive_reverse_cpu() {
1548 use crate::utils::ops::CumSumOp;
1549 use candle_core::Tensor;
1550 let device = candle_core::Device::Cpu;
1551 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1552 let b = a
1553 .fast_cumsum_config(0, true, true)
1554 .unwrap()
1555 .to_vec1::<i64>()
1556 .unwrap();
1557 assert_eq!(b, [10, 9, 7, 4]);
1558 }
1559
1560 #[cfg(feature = "metal")]
1561 #[test]
1562 fn test_cumsum_exclusive_forward_metal() {
1563 use crate::utils::ops::CumSumOp;
1564 use candle_core::Tensor;
1565 let device = candle_core::Device::new_metal(0).unwrap();
1566 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1567 let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1568 assert_eq!(b, [0, 1, 3, 6]);
1569 }
1570
1571 #[cfg(feature = "metal")]
1572 #[test]
1573 fn test_cumsum_inclusive_forward_metal() {
1574 use crate::utils::ops::CumSumOp;
1575 use candle_core::Tensor;
1576 let device = candle_core::Device::new_metal(0).unwrap();
1577 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1578 let b = a
1579 .fast_cumsum_config(0, true, false)
1580 .unwrap()
1581 .to_vec1::<i64>()
1582 .unwrap();
1583 assert_eq!(b, [1, 3, 6, 10]);
1584 }
1585
1586 #[cfg(feature = "metal")]
1587 #[test]
1588 fn test_cumsum_exclusive_reverse_metal() {
1589 use crate::utils::ops::CumSumOp;
1590 use candle_core::Tensor;
1591 let device = candle_core::Device::new_metal(0).unwrap();
1592 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1593 let b = a
1594 .fast_cumsum_config(0, false, true)
1595 .unwrap()
1596 .to_vec1::<i64>()
1597 .unwrap();
1598 assert_eq!(b, [9, 7, 4, 0]);
1599 }
1600
1601 #[cfg(feature = "metal")]
1602 #[test]
1603 fn test_cumsum_inclusive_reverse_metal() {
1604 use crate::utils::ops::CumSumOp;
1605 use candle_core::Tensor;
1606 let device = candle_core::Device::new_metal(0).unwrap();
1607 let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1608 let b = a
1609 .fast_cumsum_config(0, true, true)
1610 .unwrap()
1611 .to_vec1::<i64>()
1612 .unwrap();
1613 assert_eq!(b, [10, 9, 7, 4]);
1614 }
1615
1616 #[test]
1617 fn test_nonzero_cpu() {
1618 use crate::utils::ops::NonZeroOp;
1619 use candle_core::Tensor;
1620 let device = candle_core::Device::Cpu;
1621 let a = Tensor::from_vec(
1622 vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1623 &[2, 4],
1624 &device,
1625 )
1626 .unwrap();
1627 let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1628 assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1629 }
1630
1631 #[cfg(feature = "cuda")]
1632 #[test]
1633 fn test_nonzero_cuda() {
1634 use crate::utils::ops::NonZeroOp;
1635 use candle_core::Tensor;
1636 let device = candle_core::Device::new_cuda(0).unwrap();
1637 let a = Tensor::from_vec(
1638 vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1639 &[2, 4],
1640 &device,
1641 )
1642 .unwrap();
1643 let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1644 assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1645 }
1646
1647 #[test]
1648 fn test_bitwise_and_cpu() {
1649 use crate::utils::ops::BitWiseOp;
1650 use candle_core::Tensor;
1651 let device = candle_core::Device::Cpu;
1652 let a =
1653 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1654 let b =
1655 Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1656 let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1657 assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
1658 }
1659
1660 #[cfg(feature = "cuda")]
1661 #[test]
1662 fn test_bitwise_and_cuda() {
1663 use crate::utils::ops::BitWiseOp;
1664 use candle_core::Tensor;
1665 let device = candle_core::Device::new_cuda(0).unwrap();
1666 let a =
1667 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1668 let b =
1669 Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
1670 let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1671 assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
1672 }
1673
1674 #[test]
1675 fn test_bitwise_or_cpu() {
1676 use crate::utils::ops::BitWiseOp;
1677 use candle_core::Tensor;
1678 let device = candle_core::Device::Cpu;
1679 let a =
1680 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1681 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1682 let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1683 assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1684 }
1685
1686 #[cfg(feature = "cuda")]
1687 #[test]
1688 fn test_bitwise_or_cuda() {
1689 use crate::utils::ops::BitWiseOp;
1690 use candle_core::Tensor;
1691 let device = candle_core::Device::new_cuda(0).unwrap();
1692 let a =
1693 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1694 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1695 let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1696 assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1697 }
1698
1699 #[test]
1700 fn test_bitwise_xor_cpu() {
1701 use crate::utils::ops::BitWiseOp;
1702 use candle_core::Tensor;
1703 let device = candle_core::Device::Cpu;
1704 let a =
1705 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1706 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1707 let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1708 assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1709 }
1710
1711 #[cfg(feature = "cuda")]
1712 #[test]
1713 fn test_bitwise_xor_cuda() {
1714 use crate::utils::ops::BitWiseOp;
1715 use candle_core::Tensor;
1716 let device = candle_core::Device::new_cuda(0).unwrap();
1717 let a =
1718 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1719 let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1720 let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1721 assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1722 }
1723
1724 #[test]
1725 fn test_nonzero_and() {
1726 use crate::utils::ops::{BitWiseOp, NonZeroOp};
1727 use candle_core::{Device, Tensor};
1728
1729 let input1 = Tensor::from_vec(
1730 vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
1731 (10,),
1732 &Device::Cpu,
1733 )
1734 .unwrap();
1735 let input2 = Tensor::from_vec(
1736 vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
1737 (10,),
1738 &Device::Cpu,
1739 )
1740 .unwrap();
1741 let input = Tensor::stack(&[input1, input2], 0).unwrap();
1742
1743 let lt = input.lt(0.0).unwrap();
1744 let gt = input.gt(-10.0).unwrap();
1745 let res = lt
1746 .bitwise_and(>)
1747 .unwrap()
1748 .nonzero()
1749 .unwrap()
1750 .to_vec2::<u32>()
1751 .unwrap();
1752
1753 assert_eq!(
1754 res,
1755 [
1756 [0, 3],
1757 [0, 4],
1758 [0, 5],
1759 [0, 6],
1760 [1, 0],
1761 [1, 3],
1762 [1, 5],
1763 [1, 6]
1764 ]
1765 );
1766 }
1767
1768 #[cfg(feature = "cuda")]
1769 #[test]
1770 fn nonzero_and_cuda() {
1771 use crate::utils::ops::{BitWiseOp, NonZeroOp};
1772 use candle_core::{Device, Tensor};
1773
1774 let device = Device::new_cuda(0).unwrap();
1775 let input1 =
1776 Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1777 let input2 =
1778 Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1779 let input = Tensor::stack(&[input1, input2], 0).unwrap();
1780
1781 let lt = input.lt(0.0).unwrap();
1782 let gt = input.gt(-10.0).unwrap();
1783 let res = lt
1784 .bitwise_and(>)
1785 .unwrap()
1786 .nonzero()
1787 .unwrap()
1788 .to_vec2::<u32>()
1789 .unwrap();
1790
1791 assert_eq!(
1792 res,
1793 [
1794 [0, 3],
1795 [0, 4],
1796 [0, 5],
1797 [0, 6],
1798 [1, 0],
1799 [1, 3],
1800 [1, 5],
1801 [1, 6]
1802 ]
1803 );
1804 }
1805
1806 #[test]
1807 fn test_bitpack_8bit_cpu() {
1808 use crate::HqqBits;
1809 use candle_core::{Device, Tensor};
1810 let bits = HqqBits::Eight;
1811 let device = Device::Cpu;
1812 let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1813 let c = bits.bitpack_type()(wq.clone())
1814 .unwrap()
1815 .to_vec2::<u8>()
1816 .unwrap();
1817 assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1818 }
1819
1820 #[cfg(feature = "cuda")]
1821 #[test]
1822 fn test_bitpack_8bit_cuda() {
1823 use crate::HqqBits;
1824 use candle_core::DType;
1825 use candle_core::{Device, Tensor};
1826 let bits = HqqBits::Eight;
1827 let device = Device::new_cuda(0).unwrap();
1828 let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1829 let c = bits.bitpack_type()(wq.clone())
1830 .unwrap()
1831 .to_dtype(DType::U8)
1832 .unwrap()
1833 .to_vec2::<u8>()
1834 .unwrap();
1835 assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1836 }
1837
1838 #[cfg(feature = "metal")]
1839 #[test]
1840 fn test_bitpack_8bit_metal() {
1841 use crate::HqqBits;
1842 use candle_core::{Device, Tensor};
1843 let bits = HqqBits::Eight;
1844 let device = Device::new_metal(0).unwrap();
1845 let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1846 let c = bits.bitpack_type()(wq.clone())
1847 .unwrap()
1848 .to_vec2::<u8>()
1849 .unwrap();
1850 assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1851 }
1852
1853 #[test]
1854 fn test_bitpack_4bit() {
1855 use crate::HqqBits;
1856 use candle_core::{Device, Tensor};
1857 let bits = HqqBits::Four;
1858 let device = Device::Cpu;
1859 let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1860 let c = bits.bitpack_type()(wq.clone())
1861 .unwrap()
1862 .to_vec2::<u8>()
1863 .unwrap();
1864 assert_eq!(c, [[19, 36]]);
1865 }
1866
1867 #[cfg(feature = "cuda")]
1868 #[test]
1869 fn test_bitpack_4bit_cuda() {
1870 use crate::HqqBits;
1871 use candle_core::{Device, Tensor};
1872 let bits = HqqBits::Four;
1873 let device = Device::new_cuda(0).unwrap();
1874 let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1875 let c = bits.bitpack_type()(wq.clone())
1876 .unwrap()
1877 .to_vec2::<u8>()
1878 .unwrap();
1879 assert_eq!(c, [[19, 36]]);
1880 }
1881
1882 #[cfg(feature = "metal")]
1883 #[test]
1884 fn test_bitpack_4bit_metal() {
1885 use crate::HqqBits;
1886 use candle_core::{Device, Tensor};
1887 let bits = HqqBits::Four;
1888 let device = Device::new_metal(0).unwrap();
1889 let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1890 let c = bits.bitpack_type()(wq.clone())
1891 .unwrap()
1892 .to_vec2::<u8>()
1893 .unwrap();
1894 assert_eq!(c, [[19, 36]]);
1895 }
1896 #[cfg(feature = "metal")]
1898 #[test]
1899 fn test_sort_and_argsort_vector_metal() {
1900 use crate::utils::ops::SortOp;
1901 use candle_core::Tensor;
1902
1903 let device = candle_core::Device::new_metal(0).unwrap();
1904 let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
1905
1906 let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1908 assert_eq!(sorted, [1, 2, 3, 4]);
1909
1910 let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1912 assert_eq!(idx, [1, 3, 0, 2]);
1913 }
1914
1915 #[cfg(feature = "metal")]
1916 #[test]
1917 fn test_sort_and_argsort_matrix_axis1_metal() {
1918 use crate::utils::ops::SortOp;
1919 use candle_core::Tensor;
1920
1921 let device = candle_core::Device::new_metal(0).unwrap();
1922 let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
1926
1927 let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
1929 assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
1930
1931 let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
1933 assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
1934 }
1935
1936 #[cfg(feature = "metal")]
1938 #[test]
    fn test_sort_and_argsort_vector_4096_metal() {
1940 use crate::utils::ops::SortOp;
1941 use candle_core::Tensor;
1942
1943 const N: usize = 4096;
1944
1945 let device = candle_core::Device::new_metal(0).expect("Metal device");
1946
1947 let vals: Vec<i32> = (0..N as i32).rev().collect();
1949 let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
1950
1951 let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1953 let expected: Vec<i32> = (0..N as i32).collect();
1954 assert_eq!(sorted, expected);
1955
1956 let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1958 for (i, &v) in idx.iter().enumerate() {
1960 assert_eq!(v as usize, N - 1 - i);
1961 }
1962 }
1963}