1use candle_core::{
2 backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3 Result, Shape, Tensor, WithDType, D,
4};
5
6use std::{
7 fmt::Display,
8 ops::{BitAnd, BitOr, BitXor},
9};
10
11#[cfg(feature = "cuda")]
12use crate::cuda::ffi;
13#[cfg(feature = "cuda")]
14use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
15#[cfg(feature = "cuda")]
16use half::{bf16, f16};
17#[cfg(feature = "cuda")]
18use std::ffi::c_void;
/// Which element-wise bitwise operation a `BitWise` kernel applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitWiseOpEnum {
    And,
    Or,
    Xor,
}
24
25impl Display for BitWiseOpEnum {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 BitWiseOpEnum::And => write!(f, "And"),
29 BitWiseOpEnum::Or => write!(f, "Or"),
30 BitWiseOpEnum::Xor => write!(f, "Xor"),
31 }
32 }
33}
34
/// `CustomOp2` payload: applies the selected bitwise operation element-wise.
struct BitWise {
    pub op: BitWiseOpEnum,
}
38
39impl BitWise {
40 pub fn new(op: BitWiseOpEnum) -> Self {
41 Self { op }
42 }
43
44 fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
45 &self,
46 vs1: &[T],
47 vs2: &[T],
48 ) -> Vec<T> {
49 let n = vs1.len();
50 let mut result = Vec::with_capacity(n);
51 for i in 0..n {
52 let v1 = vs1[i];
53 let v2 = vs2[i];
54 let r = match self.op {
55 BitWiseOpEnum::And => v1 & v2,
56 BitWiseOpEnum::Or => v1 | v2,
57 BitWiseOpEnum::Xor => v1 ^ v2,
58 };
59 result.push(r);
60 }
61 result
62 }
63}
64
impl CustomOp2 for BitWise {
    fn name(&self) -> &'static str {
        "bitwise"
    }

    /// CPU path: validates that both inputs share layout and dtype, then runs
    /// the element-wise op on integer storages. Float dtypes are rejected.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1 != l2 {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise",
            });
        }
        // Dtypes already match, so `as_slice` on s2 cannot fail here.
        match s1 {
            CpuStorage::U8(vs1) => {
                let vs2 = s2.as_slice::<u8>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs2 = s2.as_slice::<u32>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs2 = s2.as_slice::<i64>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs2 = s2.as_slice::<i16>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs2 = s2.as_slice::<i32>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
        }
    }
    /// CUDA path: extracts raw device pointers per dtype, then dispatches to
    /// the matching `ffi::bitwise_*` kernel.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &CudaStorage,
        l1: &Layout,
        s2: &CudaStorage,
        l2: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        if l1 != l2 {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise",
            });
        }
        let dev = s1.device().clone();
        let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
            DType::U8 => {
                let d_in1_ptr = *s1.as_cuda_slice::<u8>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<u8>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::U32 => {
                let d_in1_ptr = *s1.as_cuda_slice::<u32>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<u32>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I64 => {
                let d_in1_ptr = *s1.as_cuda_slice::<i64>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<i64>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I32 => {
                let d_in1_ptr = *s1.as_cuda_slice::<i32>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<i32>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            // NOTE(review): I16 pointers are extracted here, but the dispatch
            // match below has no I16 arm, so an I16 input reaches
            // `unreachable!()` and panics at runtime. Either add an
            // `ffi::bitwise_*_i16` dispatch arm (if those kernels exist) or
            // return UnsupportedDTypeForOp here — TODO confirm against ffi.
            DType::I16 => {
                let d_in1_ptr = *s1.as_cuda_slice::<i16>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<i16>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::BF16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise"));
            }
            DType::F16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise"));
            }
            DType::F32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise"));
            }
            DType::F64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
            }
            DType::F8E4M3 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::U32 => {
                let d_out = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I64 => {
                let d_out = unsafe { dev.alloc::<i64>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            // Only pointer-extracting dtypes above can reach here; float
            // dtypes already returned Err — but see the I16 NOTE above.
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }
}
313
#[allow(dead_code)]
/// Element-wise bitwise operations on integer tensors (u8/u32/i16/i32/i64).
pub trait BitWiseOp {
    /// Element-wise `self & rhs`. Both tensors must share shape and dtype.
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
    /// Element-wise `self | rhs`. Both tensors must share shape and dtype.
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
    /// Element-wise `self ^ rhs`. Both tensors must share shape and dtype.
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
}
320
impl BitWiseOp for Tensor {
    // With the metal feature there is no metal kernel for this op: move both
    // operands to the CPU, run the op there, and move the result back.
    #[cfg(feature = "metal")]
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        let original_device = rhs.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op2_no_bwd(
                &rhs.to_device(&candle_core::Device::Cpu)?,
                &BitWise::new(BitWiseOpEnum::And),
            )?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::And))
    }

    // Same CPU round-trip strategy as bitwise_and above.
    #[cfg(feature = "metal")]
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        let original_device = rhs.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op2_no_bwd(
                &rhs.to_device(&candle_core::Device::Cpu)?,
                &BitWise::new(BitWiseOpEnum::Or),
            )?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Or))
    }

    // Same CPU round-trip strategy as bitwise_and above.
    #[cfg(feature = "metal")]
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        let original_device = rhs.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op2_no_bwd(
                &rhs.to_device(&candle_core::Device::Cpu)?,
                &BitWise::new(BitWiseOpEnum::Xor),
            )?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Xor))
    }
}
367
368struct NonZero {}
369impl NonZero {
370 fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
373 let n = layout.dims().len();
374 let mut result = Vec::new();
375 let mut indices = vec![0u32; n];
376 for (i, v) in vs.iter().enumerate() {
377 if !v.is_zero() {
378 let mut idx = i;
379 for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
380 let d = idx % dim;
381 indices[dim_index] = u32::try_from(d).unwrap();
382 idx /= dim;
383 }
384 result.extend_from_slice(&indices);
385 }
386 }
387 result
388 }
389}
390
/// Counts non-zero elements in the device buffer `d_in` (of `n` elements,
/// interpreted per `dtype`) by launching the matching CUDA reduction kernel
/// on `stream`.
#[cfg(feature = "cuda")]
fn count_nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) -> u32 {
    // SAFETY: caller must pass a valid device pointer whose element type
    // matches `dtype` and which holds at least `n` elements.
    unsafe {
        match dtype {
            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}
413
/// Launches the CUDA `nonzero` kernel for `dtype`, writing `num_nonzero`
/// rows of `num_dims` u32 coordinates into `d_out` on `stream`.
#[allow(clippy::too_many_arguments)]
#[cfg(feature = "cuda")]
fn nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    num_nonzero: u32,
    dims: *const c_void,
    num_dims: u32,
    d_out: *mut c_void,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) {
    // SAFETY: caller must pass valid device pointers; `d_in` holds `n`
    // elements of `dtype`, `dims` holds `num_dims` u32s, and `d_out` has room
    // for `num_nonzero * num_dims` u32s.
    unsafe {
        match dtype {
            candle_core::DType::U8 => {
                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::U32 => {
                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I64 => {
                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            // Fixed: this arm previously called `ffi::nonzero_i64`, which
            // reads 8-byte elements from a 4-byte i32 buffer.
            candle_core::DType::I32 => {
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I16 => {
                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::BF16 => {
                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F16 => {
                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F32 => {
                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F64 => {
                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}
459
impl CustomOp1 for NonZero {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    /// CPU path: returns a `[num_nonzero, ndims]` u32 tensor of coordinates.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "nonzero" });
        }
        let result = match storage {
            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
        };
        // NOTE(review): `index_len` is 0 for a rank-0 (scalar) input, which
        // would divide by zero below — presumably callers never pass scalars;
        // TODO confirm.
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
        let result = CpuStorage::U32(result);
        let shape = Shape::from_dims(&[result_len, index_len]);
        Ok((result, shape))
    }
    /// CUDA path: counts non-zeros first to size the output, then launches
    /// the coordinate-writing kernel.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let dev = storage.device().clone();
        let d_in = match storage.dtype() {
            candle_core::DType::U8 => *storage.as_cuda_slice::<u8>()?.device_ptr(),
            candle_core::DType::U32 => *storage.as_cuda_slice::<u32>()?.device_ptr(),
            candle_core::DType::I32 => *storage.as_cuda_slice::<i32>()?.device_ptr(),
            candle_core::DType::I16 => *storage.as_cuda_slice::<i16>()?.device_ptr(),
            candle_core::DType::I64 => *storage.as_cuda_slice::<i64>()?.device_ptr(),
            candle_core::DType::BF16 => *storage.as_cuda_slice::<bf16>()?.device_ptr(),
            candle_core::DType::F16 => *storage.as_cuda_slice::<f16>()?.device_ptr(),
            candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
            candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
            candle_core::DType::F8E4M3 => todo!(),
        } as *const c_void;
        let n = layout.shape().elem_count();

        let num_nonzero =
            count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream());
        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
        let d_out_ptr = *d_out.device_ptr() as *mut c_void;
        let dims = layout
            .dims()
            .iter()
            .map(|&x| u32::try_from(x).unwrap())
            .collect::<Vec<u32>>();
        let d_dims = dev
            .htod_copy(dims)
            .map_err(|_| Error::Msg("Failed to copy dims to device".to_string()))?;
        let d_dims_ptr = *d_dims.device_ptr() as *const c_void;
        nonzero_cuda(
            storage.dtype(),
            d_in,
            u32::try_from(n)?,
            num_nonzero,
            d_dims_ptr,
            u32::try_from(layout.dims().len())?,
            d_out_ptr,
            *dev.cu_stream(),
        );
        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
        Ok((dst, shape))
    }
}
540
/// `torch.nonzero`-style op: coordinates of all non-zero elements.
pub trait NonZeroOp {
    /// Returns a `[num_nonzero, ndims]` u32 tensor of coordinates.
    /// Requires a contiguous input.
    fn nonzero(&self) -> Result<Tensor>;
}
544
545impl NonZeroOp for Tensor {
546 #[cfg(feature = "metal")]
547 fn nonzero(&self) -> Result<Tensor> {
548 if !self.is_contiguous() {
549 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
550 }
551 let original_device = self.device();
552 self.to_device(&candle_core::Device::Cpu)?
553 .apply_op1_no_bwd(&NonZero {})?
554 .to_device(original_device)
555 }
556 #[cfg(not(feature = "metal"))]
557 fn nonzero(&self) -> Result<Tensor> {
558 if !self.is_contiguous() {
559 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
560 }
561 self.apply_op1_no_bwd(&NonZero {})
562 }
563}
564
/// Custom op producing argsort indices over the last dimension (CUDA only).
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ArgSort {
    asc: bool,        // ascending order when true, descending otherwise
    last_dim: usize,  // length of the dimension being sorted
    inplace: bool,    // when true the kernel also sorts the input buffer
}
572
impl candle_core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    // No CPU implementation: callers are expected to use candle's built-in
    // `arg_sort_last_dim` on CPU (see the cfg dispatch in TopKLastDimOp).
    fn cpu_fwd(
        &self,
        _: &candle_core::CpuStorage,
        _: &candle_core::Layout,
    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
        panic!("not implemented!")
    }

    /// CUDA path: treats the tensor as `nrows` rows of `ncols` (= last_dim)
    /// and launches the asort kernel per dtype; returns u32 indices with the
    /// input's shape.
    #[allow(clippy::cast_possible_truncation)]
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;
        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
        use candle_core::cuda_backend::CudaStorageSlice;
        use candle_core::cuda_backend::WrapErr;
        let dev = storage.device();
        let elem_count = layout.shape().elem_count();
        let ncols = self.last_dim as i32;
        let nrows = elem_count as i32 / ncols;
        let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;

        use std::ffi::c_void;

        // I16/I32 have no asort kernels, hence the catch-all bails below.
        let src = match &storage.slice {
            CudaStorageSlice::U8(inp) => inp.device_ptr(),
            CudaStorageSlice::U32(inp) => inp.device_ptr(),
            CudaStorageSlice::I64(inp) => inp.device_ptr(),
            CudaStorageSlice::BF16(inp) => inp.device_ptr(),
            CudaStorageSlice::F16(inp) => inp.device_ptr(),
            CudaStorageSlice::F32(inp) => inp.device_ptr(),
            CudaStorageSlice::F64(inp) => inp.device_ptr(),
            _ => candle_core::bail!("Unexpected dtype in asort"),
        };
        let src_ptr = *src as *const c_void;
        let dst_ptr = *dst.device_ptr() as *mut c_void;
        let stream = *dev.cu_stream() as i64;
        unsafe {
            if self.asc {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            } else {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            }
        }
        let dst_ret = candle_core::cuda_backend::CudaStorage {
            slice: CudaStorageSlice::U32(dst),
            device: dev.clone(),
        };
        Ok((dst_ret, layout.shape().clone()))
    }
}
678
#[allow(dead_code)]
/// Sorting along the last dimension via the `ArgSort` custom op (CUDA only).
pub trait ArgSortOp {
    /// Returns the indices that sort the last dim (`asc` = ascending).
    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
    /// Returns `(sorted_values, sort_indices)` along the last dim.
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
}
684
685impl ArgSortOp for Tensor {
686 fn arg_sort(&self, asc: bool) -> Result<Tensor> {
692 if !self.is_contiguous() {
693 return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
694 }
695 let last_dim = match self.dims().last() {
696 Some(last_dim) => *last_dim,
697 None => candle_core::bail!("empty last-dim in arg-sort"),
698 };
699 self.apply_op1_no_bwd(&ArgSort {
701 asc,
702 last_dim,
703 inplace: false,
704 })
705 }
706
707 fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
714 if !self.is_contiguous() {
715 return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
716 }
717 let last_dim = match self.dims().last() {
718 Some(last_dim) => *last_dim,
719 None => candle_core::bail!("empty last-dim in arg-sort"),
720 };
721 let sorted = self.copy()?;
722
723 let asort = sorted.apply_op1_no_bwd(&ArgSort {
724 asc,
725 last_dim,
726 inplace: true,
727 })?;
728
729 Ok((sorted, asort))
730 }
731}
732
/// Result of a top-k query: the selected values and their u32 indices.
#[allow(dead_code)]
pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}
738
/// Top-k selection along the last dimension.
pub trait TopKLastDimOp {
    /// Top `topk` entries of the last dim, sorted descending by value.
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    /// Same entries as `topk`, but reordered by ascending original index
    /// instead of by value.
    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}
750
impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Sort descending, then take the leading `topk` slice of both values
        // and indices. CUDA uses the custom sort; CPU uses candle's built-in.
        #[cfg(feature = "cuda")]
        let (values, sorted_indices) = self.sort(false)?;
        #[cfg(not(feature = "cuda"))]
        let (values, sorted_indices) = self.sort_last_dim(false)?;
        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
        Ok(TopKOutput {
            values: topk_values,
            indices: topk_indices,
        })
    }

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
        let TopKOutput { values, indices } = self.topk(topk)?;
        // Re-sort the selected indices ascending, then gather values/indices
        // into that order.
        #[cfg(feature = "cuda")]
        let reorder_indices = indices.arg_sort(true)?;
        #[cfg(not(feature = "cuda"))]
        let reorder_indices = indices.arg_sort_last_dim(true)?;
        // Round-trip through F32 before/after gather — presumably because
        // gather requires this dtype pairing here; TODO confirm.
        let topk_indices_unsorted = indices
            .to_dtype(DType::F32)?
            .gather(&reorder_indices, D::Minus1)?
            .to_dtype(DType::U32)?;
        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
        Ok(TopKOutput {
            values: topk_values_unsorted,
            indices: topk_indices_unsorted,
        })
    }
}
785
/// `torch.repeat_interleave`-style element repetition.
pub trait RepeatInterleaveOp {
    /// Repeats each entry along `dim` `repeats` times.
    fn repeat_interleave(&self, repeats: usize, dim: usize) -> Result<Tensor>;
    /// Flattens `self`, then repeats element `i` `repeats[i]` times.
    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
}
790
791impl RepeatInterleaveOp for Tensor {
792 fn repeat_interleave(&self, repeats: usize, dim: usize) -> Result<Tensor> {
793 assert!(self.dtype().is_float());
795 #[allow(clippy::cast_possible_truncation)]
796 let indices = Tensor::new(
797 (0..self.dim(dim)?)
798 .flat_map(|i| vec![i as u32; repeats])
799 .collect::<Vec<_>>(),
800 self.device(),
801 )?;
802 self.index_select(&indices, dim)
803 }
804
805 fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
806 let xs = self.flatten_all()?;
807 if repeats.len() != xs.dim(0)? {
808 candle_core::bail!(
809 "repeats ({}) must match flattened self length ({})",
810 repeats.len(),
811 xs.dim(0)?
812 );
813 }
814 #[allow(clippy::cast_possible_truncation)]
815 let indices = Tensor::new(
816 (0..xs.dim(0)?)
817 .flat_map(|i| vec![i as u32; repeats[i] as usize])
818 .collect::<Vec<_>>(),
819 xs.device(),
820 )?;
821 xs.index_select(&indices, 0)
822 }
823}
824
/// `torch.split`-style partitioning along one dimension.
pub trait SplitOp {
    /// Splits `self` along `dim` into consecutive chunks of the given sizes.
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}
828
829impl SplitOp for Tensor {
830 fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
831 let dim = dim.to_index(self.shape(), "split")?;
832 let mut split_res = Vec::new();
833 let mut index = 0;
834 for split in splits {
835 split_res.push(self.narrow(dim, index, *split)?);
836 index += *split;
837 }
838 Ok(split_res)
839 }
840}
841
/// `torch.bincount`-style histogram of u32 tensor values.
pub trait BincountOp {
    /// Counts occurrences of each value; result has at least `minlength` bins.
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}
845
846fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
847 use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
860
861 if values.is_empty() {
863 return vec![0u32; minlength as usize];
864 }
865
866 let max_val = *values.par_iter().max().unwrap();
869
870 let result_len = (max_val + 1).max(minlength) as usize;
872
873 values
876 .par_iter()
877 .fold(
878 || vec![0u32; result_len],
879 |mut local_hist, &v| {
880 unsafe {
882 *local_hist.get_unchecked_mut(v as usize) += 1;
883 }
884 local_hist
885 },
886 )
887 .reduce(
889 || vec![0u32; result_len],
890 |mut global_hist, local_hist| {
891 for i in 0..result_len {
892 unsafe {
894 *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
895 }
896 }
897 global_hist
898 },
899 )
900}
901
902impl BincountOp for Tensor {
903 fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
904 let values = self.to_vec1::<u32>()?;
905
906 Ok(bincount(&values, minlength))
907 }
908}
909
// Compile the test module only under `cargo test`; previously it was built
// into every configuration, producing dead-code warnings and binary bloat.
#[cfg(test)]
mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }

    #[test]
    fn test_nonzero_cpu() {
        use crate::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_cuda() {
        use crate::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[test]
    fn test_bitwise_and_cpu() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_and_cuda() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
    }

    #[test]
    fn test_bitwise_or_cpu() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_bitwise_xor_cpu() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_xor_cuda() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_nonzero_and() {
        use crate::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let input1 = Tensor::from_vec(
            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input2 = Tensor::from_vec(
            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn nonzero_and_cuda() {
        use crate::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let device = Device::new_cuda(0).unwrap();
        let input1 =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input2 =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }
}