mistralrs_core/ops.rs

use candle_core::{
    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
    Result, Shape, Tensor, WithDType, D,
};

use std::{
    fmt::Display,
    ops::{BitAnd, BitOr, BitXor},
};

#[cfg(feature = "cuda")]
use crate::cuda::ffi;
#[cfg(feature = "cuda")]
use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
#[cfg(feature = "cuda")]
use half::{bf16, f16};
#[cfg(feature = "cuda")]
use std::ffi::c_void;

pub enum BitWiseOpEnum {
    And,
    Or,
    Xor,
}

impl Display for BitWiseOpEnum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BitWiseOpEnum::And => write!(f, "And"),
            BitWiseOpEnum::Or => write!(f, "Or"),
            BitWiseOpEnum::Xor => write!(f, "Xor"),
        }
    }
}

struct BitWise {
    pub op: BitWiseOpEnum,
}

impl BitWise {
    pub fn new(op: BitWiseOpEnum) -> Self {
        Self { op }
    }

    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
        &self,
        vs1: &[T],
        vs2: &[T],
    ) -> Vec<T> {
        let n = vs1.len();
        let mut result = Vec::with_capacity(n);
        for i in 0..n {
            let v1 = vs1[i];
            let v2 = vs2[i];
            let r = match self.op {
                BitWiseOpEnum::And => v1 & v2,
                BitWiseOpEnum::Or => v1 | v2,
                BitWiseOpEnum::Xor => v1 ^ v2,
            };
            result.push(r);
        }
        result
    }
}

impl CustomOp2 for BitWise {
    fn name(&self) -> &'static str {
        "bitwise"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1 != l2 {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise",
            });
        }
        match s1 {
            CpuStorage::U8(vs1) => {
                let vs2 = s2.as_slice::<u8>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs2 = s2.as_slice::<u32>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs2 = s2.as_slice::<i64>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs2 = s2.as_slice::<i16>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs2 = s2.as_slice::<i32>().unwrap();
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
        }
    }
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &CudaStorage,
        l1: &Layout,
        s2: &CudaStorage,
        l2: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        if l1 != l2 {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise",
            });
        }
        let dev = s1.device().clone();
        let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
            DType::U8 => {
                let d_in1_ptr = *s1.as_cuda_slice::<u8>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<u8>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::U32 => {
                let d_in1_ptr = *s1.as_cuda_slice::<u32>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<u32>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I64 => {
                let d_in1_ptr = *s1.as_cuda_slice::<i64>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<i64>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I32 => {
                let d_in1_ptr = *s1.as_cuda_slice::<i32>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<i32>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I16 => {
                let d_in1_ptr = *s1.as_cuda_slice::<i16>()?.device_ptr() as *const c_void;
                let d_in2_ptr = *s2.as_cuda_slice::<i16>()?.device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::BF16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise"));
            }
            DType::F16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise"));
            }
            DType::F32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise"));
            }
            DType::F64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
            }
            DType::F8E4M3 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::U32 => {
                let d_out = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I64 => {
                let d_out = unsafe { dev.alloc::<i64>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
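            // NOTE: I16 pointers are gathered in the match above, but no I16 kernel is
            // dispatched here, so an I16 input currently falls through to `unreachable!()`.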
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }
}

#[allow(dead_code)]
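/// Elementwise bitwise `and`/`or`/`xor` over integer tensors.
///
/// A minimal usage sketch, assuming two `u8` mask tensors of identical shape on the
/// same device (float dtypes return `UnsupportedDTypeForOp`):
///
/// ```ignore
/// use crate::ops::BitWiseOp;
/// use candle_core::{Device, Tensor};
///
/// let a = Tensor::new(&[0b1010u8, 0b1100u8], &Device::Cpu)?;
/// let b = Tensor::new(&[0b0110u8, 0b1010u8], &Device::Cpu)?;
/// let and = a.bitwise_and(&b)?; // [0b0010, 0b1000]
/// let or = a.bitwise_or(&b)?;   // [0b1110, 0b1110]
/// let xor = a.bitwise_xor(&b)?; // [0b1100, 0b0110]
/// ```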
pub trait BitWiseOp {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
}

impl BitWiseOp for Tensor {
    #[cfg(feature = "metal")]
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        let original_device = rhs.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op2_no_bwd(
                &rhs.to_device(&candle_core::Device::Cpu)?,
                &BitWise::new(BitWiseOpEnum::And),
            )?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::And))
    }

    #[cfg(feature = "metal")]
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        let original_device = rhs.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op2_no_bwd(
                &rhs.to_device(&candle_core::Device::Cpu)?,
                &BitWise::new(BitWiseOpEnum::Or),
            )?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Or))
    }

    #[cfg(feature = "metal")]
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        let original_device = rhs.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op2_no_bwd(
                &rhs.to_device(&candle_core::Device::Cpu)?,
                &BitWise::new(BitWiseOpEnum::Xor),
            )?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Xor))
    }
}

struct NonZero {}
impl NonZero {
    // Sequential version

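    // The flat index of each nonzero element is unravelled into per-dimension indices
    // by repeated modulo/division over the dims, last dimension first: e.g. for dims
    // [2, 4], flat index 6 maps to [1, 2].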
    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
        let n = layout.dims().len();
        let mut result = Vec::new();
        let mut indices = vec![0u32; n];
        for (i, v) in vs.iter().enumerate() {
            if !v.is_zero() {
                let mut idx = i;
                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
                    let d = idx % dim;
                    indices[dim_index] = u32::try_from(d).unwrap();
                    idx /= dim;
                }
                result.extend_from_slice(&indices);
            }
        }
        result
    }
}

#[cfg(feature = "cuda")]
fn count_nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) -> u32 {
    unsafe {
        match dtype {
            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}

#[allow(clippy::too_many_arguments)]
#[cfg(feature = "cuda")]
fn nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    num_nonzero: u32,
    dims: *const c_void,
    num_dims: u32,
    d_out: *mut c_void,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) {
    unsafe {
        match dtype {
            candle_core::DType::U8 => {
                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::U32 => {
                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I64 => {
                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I32 => {
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I16 => {
                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::BF16 => {
                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F16 => {
                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F32 => {
                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F64 => {
                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}

impl CustomOp1 for NonZero {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "nonzero" });
        }
        let result = match storage {
            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
        };
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
        let result = CpuStorage::U32(result);
        let shape = Shape::from_dims(&[result_len, index_len]);
        Ok((result, shape))
    }
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let dev = storage.device().clone();
        let d_in = match storage.dtype() {
            candle_core::DType::U8 => *storage.as_cuda_slice::<u8>()?.device_ptr(),
            candle_core::DType::U32 => *storage.as_cuda_slice::<u32>()?.device_ptr(),
            candle_core::DType::I32 => *storage.as_cuda_slice::<i32>()?.device_ptr(),
            candle_core::DType::I16 => *storage.as_cuda_slice::<i16>()?.device_ptr(),
            candle_core::DType::I64 => *storage.as_cuda_slice::<i64>()?.device_ptr(),
            candle_core::DType::BF16 => *storage.as_cuda_slice::<bf16>()?.device_ptr(),
            candle_core::DType::F16 => *storage.as_cuda_slice::<f16>()?.device_ptr(),
            candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
            candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
            candle_core::DType::F8E4M3 => todo!(),
        } as *const c_void;
        let n = layout.shape().elem_count();

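        // Two-pass scheme: first count the nonzero elements so the output buffer can be
        // sized exactly, then launch the kernel that writes one row of per-dimension
        // indices for each nonzero element.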
        let num_nonzero =
            count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream());
        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
        let d_out_ptr = *d_out.device_ptr() as *mut c_void;
        let dims = layout
            .dims()
            .iter()
            .map(|&x| u32::try_from(x).unwrap())
            .collect::<Vec<u32>>();
        let d_dims = dev
            .htod_copy(dims)
            .map_err(|_| Error::Msg("Failed to copy dims to device".to_string()))?;
        let d_dims_ptr = *d_dims.device_ptr() as *const c_void;
        nonzero_cuda(
            storage.dtype(),
            d_in,
            u32::try_from(n)?,
            num_nonzero,
            d_dims_ptr,
            u32::try_from(layout.dims().len())?,
            d_out_ptr,
            *dev.cu_stream(),
        );
        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
        Ok((dst, shape))
    }
}

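/// Indices of the nonzero elements, returned as a `(num_nonzero, ndims)` tensor of
/// `u32` indices (one row per nonzero element), in the spirit of `torch.nonzero`.
///
/// A minimal usage sketch, assuming a contiguous CPU tensor:
///
/// ```ignore
/// use crate::ops::NonZeroOp;
/// use candle_core::{Device, Tensor};
///
/// let t = Tensor::new(&[[0u32, 7u32], [3u32, 0u32]], &Device::Cpu)?;
/// let idx = t.nonzero()?; // shape (2, 2): rows [0, 1] and [1, 0]
/// ```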
pub trait NonZeroOp {
    fn nonzero(&self) -> Result<Tensor>;
}

impl NonZeroOp for Tensor {
    #[cfg(feature = "metal")]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let original_device = self.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op1_no_bwd(&NonZero {})?
            .to_device(original_device)
    }
    #[cfg(not(feature = "metal"))]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        self.apply_op1_no_bwd(&NonZero {})
    }
}

#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
    inplace: bool,
}

impl candle_core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        _: &candle_core::CpuStorage,
        _: &candle_core::Layout,
    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
        panic!("not implemented!")
    }

    #[allow(clippy::cast_possible_truncation)]
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;
        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
        use candle_core::cuda_backend::CudaStorageSlice;
        use candle_core::cuda_backend::WrapErr;
        let dev = storage.device();
        let elem_count = layout.shape().elem_count();
        let ncols = self.last_dim as i32;
        let nrows = elem_count as i32 / ncols;
        let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;

        use std::ffi::c_void;

        let src = match &storage.slice {
            CudaStorageSlice::U8(inp) => inp.device_ptr(),
            CudaStorageSlice::U32(inp) => inp.device_ptr(),
            CudaStorageSlice::I64(inp) => inp.device_ptr(),
            CudaStorageSlice::BF16(inp) => inp.device_ptr(),
            CudaStorageSlice::F16(inp) => inp.device_ptr(),
            CudaStorageSlice::F32(inp) => inp.device_ptr(),
            CudaStorageSlice::F64(inp) => inp.device_ptr(),
            _ => candle_core::bail!("Unexpected dtype in asort"),
        };
        let src_ptr = *src as *const c_void;
        let dst_ptr = *dst.device_ptr() as *mut c_void;
        let stream = *dev.cu_stream() as i64;
        unsafe {
            if self.asc {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            } else {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            }
        }
        let dst_ret = candle_core::cuda_backend::CudaStorage {
            slice: CudaStorageSlice::U32(dst),
            device: dev.clone(),
        };
        Ok((dst_ret, layout.shape().clone()))
    }
}

#[allow(dead_code)]
pub trait ArgSortOp {
    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
}

impl ArgSortOp for Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order when
    /// it comes to ties.
    fn arg_sort(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: false,
        })
    }

    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
    /// sorted indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order when
    /// it comes to ties.
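    ///
    /// A minimal usage sketch, assuming a CUDA device (the CPU path of the underlying
    /// custom op is unimplemented and panics):
    ///
    /// ```ignore
    /// // `x` is any contiguous tensor with a non-empty last dimension.
    /// let (sorted, indices) = x.sort(false)?; // descending values plus their u32 indices
    /// ```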
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        let sorted = self.copy()?;

        let asort = sorted.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: true,
        })?;

        Ok((sorted, asort))
    }
}

#[allow(dead_code)]
pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}

pub trait TopKLastDimOp {
    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self.
    /// This expects a contiguous tensor.
    /// Note: this implements torch.topk with sorted=True.
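    ///
    /// A minimal usage sketch (on CPU this goes through `sort_last_dim`):
    ///
    /// ```ignore
    /// use crate::ops::{TopKLastDimOp, TopKOutput};
    /// use candle_core::{Device, Tensor};
    ///
    /// let x = Tensor::new(&[1f32, 4., 2., 3.], &Device::Cpu)?;
    /// let TopKOutput { values, indices } = x.topk(2)?;
    /// // values  == [4., 3.], indices == [1, 3]
    /// ```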
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self.
    /// This expects a contiguous tensor.
    /// Note: this implements torch.topk with sorted=False.
    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}

impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Sorted descending
        #[cfg(feature = "cuda")]
        let (values, sorted_indices) = self.sort(false)?;
        #[cfg(not(feature = "cuda"))]
        let (values, sorted_indices) = self.sort_last_dim(false)?;
        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
        Ok(TopKOutput {
            values: topk_values,
            indices: topk_indices,
        })
    }

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
        // Sorted descending
        let TopKOutput { values, indices } = self.topk(topk)?;
        // Reorder the indices ascending
        #[cfg(feature = "cuda")]
        let reorder_indices = indices.arg_sort(true)?;
        #[cfg(not(feature = "cuda"))]
        let reorder_indices = indices.arg_sort_last_dim(true)?;
        let topk_indices_unsorted = indices
            .to_dtype(DType::F32)?
            .gather(&reorder_indices, D::Minus1)?
            .to_dtype(DType::U32)?;
        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
        Ok(TopKOutput {
            values: topk_values_unsorted,
            indices: topk_indices_unsorted,
        })
    }
}

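/// Element repetition in the spirit of `torch.repeat_interleave`.
///
/// A minimal usage sketch; note that `repeat_interleave` asserts a float dtype, and
/// `repeat_interleave_flat` takes one repeat count per element of the flattened input:
///
/// ```ignore
/// use crate::ops::RepeatInterleaveOp;
/// use candle_core::{Device, Tensor};
///
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let a = x.repeat_interleave(2, 0)?;               // [1., 1., 2., 2., 3., 3.]
/// let b = x.repeat_interleave_flat(vec![1, 0, 2])?; // [1., 3., 3.]
/// ```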
pub trait RepeatInterleaveOp {
    fn repeat_interleave(&self, repeats: usize, dim: usize) -> Result<Tensor>;
    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
}

impl RepeatInterleaveOp for Tensor {
    fn repeat_interleave(&self, repeats: usize, dim: usize) -> Result<Tensor> {
        // For metal
        assert!(self.dtype().is_float());
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..self.dim(dim)?)
                .flat_map(|i| vec![i as u32; repeats])
                .collect::<Vec<_>>(),
            self.device(),
        )?;
        self.index_select(&indices, dim)
    }

    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
        let xs = self.flatten_all()?;
        if repeats.len() != xs.dim(0)? {
            candle_core::bail!(
                "repeats ({}) must match flattened self length ({})",
                repeats.len(),
                xs.dim(0)?
            );
        }
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..xs.dim(0)?)
                .flat_map(|i| vec![i as u32; repeats[i] as usize])
                .collect::<Vec<_>>(),
            xs.device(),
        )?;
        xs.index_select(&indices, 0)
    }
}

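/// Splits a tensor along a dimension into chunks of the given sizes (a list-of-sizes
/// `torch.split`); each chunk is produced with `narrow`.
///
/// A minimal usage sketch, assuming the split sizes sum to at most the dimension size:
///
/// ```ignore
/// use crate::ops::SplitOp;
/// use candle_core::{Device, Tensor};
///
/// let x = Tensor::new(&[1f32, 2., 3., 4., 5.], &Device::Cpu)?;
/// let parts = x.split(&[2, 3], 0)?; // [[1., 2.], [3., 4., 5.]]
/// ```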
pub trait SplitOp {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}

impl SplitOp for Tensor {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
        let dim = dim.to_index(self.shape(), "split")?;
        let mut split_res = Vec::new();
        let mut index = 0;
        for split in splits {
            split_res.push(self.narrow(dim, index, *split)?);
            index += *split;
        }
        Ok(split_res)
    }
}

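/// Histogram of a rank-1 `u32` tensor: entry `i` of the result counts occurrences of
/// `i`, and the result length is at least `minlength`.
///
/// A minimal usage sketch:
///
/// ```ignore
/// use crate::ops::BincountOp;
/// use candle_core::{Device, Tensor};
///
/// let t = Tensor::new(&[1u32, 1, 3], &Device::Cpu)?;
/// let counts = t.bincount(6)?; // [0, 2, 0, 1, 0, 0]
/// ```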
pub trait BincountOp {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}

fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
    // let max_val = values.iter().max().copied().unwrap_or(0);
    // let result_len = (max_val + 1).max(minlength);
    // values.iter().fold(
    //     // Start with a histogram vector of zeros.
    //     vec![0u32; result_len as usize],
    //     // For each value, update the histogram.
    //     |mut histogram, &value| {
    //         histogram[value as usize] += 1;
    //         histogram
    //     },
    // )

    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    // Early return if there are no values.
    if values.is_empty() {
        return vec![0u32; minlength as usize];
    }

    // Compute the maximum value in parallel.
    // SAFETY: we know `values` is nonempty.
    let max_val = *values.par_iter().max().unwrap();

    // The histogram length must cover all observed values as well as `minlength`.
    let result_len = (max_val + 1).max(minlength) as usize;

    // Build per-thread histograms in parallel.
    // We use unsafe indexing to eliminate bounds checks in the inner loop.
    values
        .par_iter()
        .fold(
            || vec![0u32; result_len],
            |mut local_hist, &v| {
                // SAFETY: v is guaranteed to be <= max_val, so it is in bounds.
                unsafe {
                    *local_hist.get_unchecked_mut(v as usize) += 1;
                }
                local_hist
            },
        )
        // Merge the per-thread histograms in parallel.
        .reduce(
            || vec![0u32; result_len],
            |mut global_hist, local_hist| {
                for i in 0..result_len {
                    // SAFETY: we know local histogram is at least result_len, as is global_hist
                    unsafe {
                        *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
                    }
                }
                global_hist
            },
        )
}

impl BincountOp for Tensor {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
        let values = self.to_vec1::<u32>()?;

        Ok(bincount(&values, minlength))
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        //  [[1, 3, 5],
        //   [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }

    #[test]
    fn test_nonzero_cpu() {
        use crate::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_cuda() {
        use crate::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[test]
    fn test_bitwise_and_cpu() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_and_cuda() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
    }

    #[test]
    fn test_bitwise_or_cpu() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_bitwise_xor_cpu() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_xor_cuda() {
        use crate::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_nonzero_and() {
        use crate::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let input1 = Tensor::from_vec(
            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input2 = Tensor::from_vec(
            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn nonzero_and_cuda() {
        use crate::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let device = Device::new_cuda(0).unwrap();
        let input1 =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input2 =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }
}