mistralrs_quant/utils/ops.rs

1use candle_core::{
2    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3    Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::{
8    fmt::Display,
9    ops::{BitAnd, BitOr, BitXor, Not, Shl},
10};
11
12#[cfg(feature = "cuda")]
13use crate::utils::{ffi, slice_ptr};
14#[cfg(feature = "cuda")]
15use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage};
16#[cfg(feature = "cuda")]
17use std::ffi::c_void;
18
19#[cfg(feature = "metal")]
20use crate::metal_kernels::SortScratchCache;
21#[cfg(feature = "metal")]
22use std::sync::OnceLock;
23
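// Process-wide scratch-buffer cache for the Metal sort/argsort kernels. It is
// initialized lazily on first use (see the `get_or_init(|| SortScratchCache::new(4))`
// calls below) so repeated launches can reuse their temporary buffers via `checkout`
// instead of reallocating them on every call.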
24#[cfg(feature = "metal")]
25static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
27struct Leftshift(usize);
28
29impl Leftshift {
30    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31        let offset = T::from_f64(self.0 as f64);
32        vs.into_par_iter().map(|v| *v << offset).collect()
33    }
34}
35
36impl CustomOp1 for Leftshift {
37    fn name(&self) -> &'static str {
38        "left"
39    }
40
41    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42        match s1 {
43            CpuStorage::U8(vs1) => {
44                let vs1 = match l1.contiguous_offsets() {
45                    Some((a, b)) => &vs1[a..b],
46                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
47                };
48                let result = self.leftshift(vs1);
49                let result = CpuStorage::U8(result);
50                Ok((result, l1.shape().clone()))
51            }
52            CpuStorage::I16(vs1) => {
53                let vs1 = match l1.contiguous_offsets() {
54                    Some((a, b)) => &vs1[a..b],
55                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
56                };
57                let result = self.leftshift(vs1);
58                let result = CpuStorage::I16(result);
59                Ok((result, l1.shape().clone()))
60            }
61            CpuStorage::U32(vs1) => {
62                let vs1 = match l1.contiguous_offsets() {
63                    Some((a, b)) => &vs1[a..b],
64                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
65                };
66                let result = self.leftshift(vs1);
67                let result = CpuStorage::U32(result);
68                Ok((result, l1.shape().clone()))
69            }
70            CpuStorage::I64(vs1) => {
71                let vs1 = match l1.contiguous_offsets() {
72                    Some((a, b)) => &vs1[a..b],
73                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
74                };
75                let result = self.leftshift(vs1);
76                let result = CpuStorage::I64(result);
77                Ok((result, l1.shape().clone()))
78            }
79            CpuStorage::I32(vs1) => {
80                let vs1 = match l1.contiguous_offsets() {
81                    Some((a, b)) => &vs1[a..b],
82                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
83                };
84                let result = self.leftshift(vs1);
85                let result = CpuStorage::I32(result);
86                Ok((result, l1.shape().clone()))
87            }
88            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
89        }
90    }
91
92    #[cfg(feature = "cuda")]
93    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
94        if !l1.is_contiguous() {
95            candle_core::bail!("Input tensor s1 must be contiguous");
96        }
97        let dev = s1.device().clone();
98        let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
99            DType::U8 => {
100                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
101                let elem_count = l1.shape().elem_count();
102                (d_in1 as *const c_void, d_in1_guard, elem_count)
103            }
104            DType::I32 => {
105                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
106                let elem_count = l1.shape().elem_count();
107                (d_in1 as *const c_void, d_in1_guard, elem_count)
108            }
109            other => {
110                return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
111            }
112        };
113        let dst = match s1.dtype() {
114            DType::U8 => {
115                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
116                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
117                unsafe {
118                    ffi::leftshift_u8(
119                        d_in1_ptr,
120                        d_out_ptr as *mut std::ffi::c_void,
121                        u32::try_from(elem_count)?,
122                        self.0 as i32,
123                    )
124                };
125                drop(d_out_guard);
126                CudaStorage::wrap_cuda_slice(d_out, dev)
127            }
128            DType::I32 => {
129                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
130                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
131                unsafe {
132                    ffi::leftshift_i32(
133                        d_in1_ptr,
134                        d_out_ptr as *mut std::ffi::c_void,
135                        u32::try_from(elem_count)?,
136                        self.0 as i32,
137                    )
138                };
139                drop(d_out_guard);
140                CudaStorage::wrap_cuda_slice(d_out, dev)
141            }
142            _ => unreachable!(),
143        };
144        Ok((dst, l1.shape().clone()))
145    }
146
147    #[cfg(feature = "metal")]
148    fn metal_fwd(
149        &self,
150        s1: &candle_core::MetalStorage,
151        l1: &Layout,
152    ) -> Result<(candle_core::MetalStorage, Shape)> {
153        if !l1.is_contiguous() {
154            candle_core::bail!("Input tensor s1 must be contiguous");
155        }
156
157        let command_buffer = s1.device().command_buffer()?;
158        command_buffer.set_label("bitwise-leftshift");
159
160        let device = s1.device();
161
162        let out_shape = l1.shape().clone();
163
164        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
165
166        crate::metal_kernels::call_bitwise_leftshift(
167            device.device(),
168            &command_buffer,
169            &crate::metal_kernels::Kernels::new(),
170            s1.dtype(),
171            s1.buffer(),
172            l1.start_offset(),
173            self.0 as u32,
174            out_shape.elem_count(),
175            &output,
176        )
177        .map_err(candle_core::Error::wrap)?;
178
179        let newstorage = candle_core::MetalStorage::new(
180            output,
181            device.clone(),
182            out_shape.elem_count(),
183            s1.dtype(),
184        );
185        Ok((newstorage, out_shape))
186    }
187}
188
189#[allow(dead_code)]
190pub trait LeftshiftOp {
191    fn leftshift(&self, n: usize) -> Result<Tensor>;
192}
193
194impl LeftshiftOp for Tensor {
195    fn leftshift(&self, n: usize) -> Result<Tensor> {
196        self.apply_op1_no_bwd(&Leftshift(n))
197    }
198}
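// Illustrative usage (editor's sketch, not part of the original file): with the
// `LeftshiftOp` trait from this module in scope, inside a fn returning
// `candle_core::Result<()>`:
//
//     use candle_core::{Device, Tensor};
//     let t = Tensor::new(&[1u32, 2, 4], &Device::Cpu)?;
//     let shifted = t.leftshift(2)?;
//     assert_eq!(shifted.to_vec1::<u32>()?, vec![4u32, 8, 16]);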
199
200pub enum BitWiseBinaryOpEnum {
201    And,
202    Or,
203    Xor,
204}
205
206impl Display for BitWiseBinaryOpEnum {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        match self {
209            BitWiseBinaryOpEnum::And => write!(f, "And"),
210            BitWiseBinaryOpEnum::Or => write!(f, "Or"),
211            BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
212        }
213    }
214}
215
216pub enum BitWiseUnaryOpEnum {
217    Not,
218}
219
220impl Display for BitWiseUnaryOpEnum {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            BitWiseUnaryOpEnum::Not => write!(f, "Not"),
224        }
225    }
226}
227
228struct BitWise {
229    pub op: BitWiseBinaryOpEnum,
230}
231
232impl BitWise {
233    pub fn new(op: BitWiseBinaryOpEnum) -> Self {
234        Self { op }
235    }
236
237    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
238        &self,
239        vs1: &[T],
240        vs2: &[T],
241    ) -> Vec<T> {
242        vs1.into_par_iter()
243            .zip_eq(vs2)
244            .map(|(v1, v2)| match self.op {
245                BitWiseBinaryOpEnum::And => *v1 & *v2,
246                BitWiseBinaryOpEnum::Or => *v1 | *v2,
247                BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
248            })
249            .collect()
250    }
251}
252
253impl CustomOp2 for BitWise {
254    fn name(&self) -> &'static str {
255        "bitwise"
256    }
257
258    fn cpu_fwd(
259        &self,
260        s1: &CpuStorage,
261        l1: &Layout,
262        s2: &CpuStorage,
263        l2: &Layout,
264    ) -> Result<(CpuStorage, Shape)> {
265        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
266            return Err(Error::ShapeMismatchBinaryOp {
267                lhs: l1.shape().clone(),
268                rhs: l2.shape().clone(),
269                op: "bitwise-op",
270            });
271        }
272        if s1.dtype() != s2.dtype() {
273            return Err(Error::DTypeMismatchBinaryOp {
274                lhs: s1.dtype(),
275                rhs: s2.dtype(),
276                op: "bitwise-op",
277            });
278        }
279        if !l1.is_contiguous() {
280            candle_core::bail!("Input tensor s1 must be contiguous");
281        }
282        if !l2.is_contiguous() {
283            candle_core::bail!("Input tensor s2 must be contiguous");
284        }
285
286        match s1 {
287            CpuStorage::U8(vs1) => {
288                let vs2 = s2.as_slice::<u8>().unwrap();
289                let vs1 = match l1.contiguous_offsets() {
290                    Some((a, b)) => &vs1[a..b],
291                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
292                };
293                let vs2 = match l2.contiguous_offsets() {
294                    Some((a, b)) => &vs2[a..b],
295                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
296                };
297                let result = self.bitwise(vs1, vs2);
298                let result = CpuStorage::U8(result);
299                Ok((result, l1.shape().clone()))
300            }
301            CpuStorage::U32(vs1) => {
302                let vs2 = s2.as_slice::<u32>().unwrap();
303                let vs1 = match l1.contiguous_offsets() {
304                    Some((a, b)) => &vs1[a..b],
305                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
306                };
307                let vs2 = match l2.contiguous_offsets() {
308                    Some((a, b)) => &vs2[a..b],
309                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
310                };
311                let result = self.bitwise(vs1, vs2);
312                let result = CpuStorage::U32(result);
313                Ok((result, l1.shape().clone()))
314            }
315            CpuStorage::I64(vs1) => {
316                let vs2 = s2.as_slice::<i64>().unwrap();
317                let vs1 = match l1.contiguous_offsets() {
318                    Some((a, b)) => &vs1[a..b],
319                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
320                };
321                let vs2 = match l2.contiguous_offsets() {
322                    Some((a, b)) => &vs2[a..b],
323                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
324                };
325                let result = self.bitwise(vs1, vs2);
326                let result = CpuStorage::I64(result);
327                Ok((result, l1.shape().clone()))
328            }
329            CpuStorage::I16(vs1) => {
330                let vs2 = s2.as_slice::<i16>().unwrap();
331                let vs1 = match l1.contiguous_offsets() {
332                    Some((a, b)) => &vs1[a..b],
333                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
334                };
335                let vs2 = match l2.contiguous_offsets() {
336                    Some((a, b)) => &vs2[a..b],
337                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
338                };
339                let result = self.bitwise(vs1, vs2);
340                let result = CpuStorage::I16(result);
341                Ok((result, l1.shape().clone()))
342            }
343            CpuStorage::I32(vs1) => {
344                let vs2 = s2.as_slice::<i32>().unwrap();
345                let vs1 = match l1.contiguous_offsets() {
346                    Some((a, b)) => &vs1[a..b],
347                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
348                };
349                let vs2 = match l2.contiguous_offsets() {
350                    Some((a, b)) => &vs2[a..b],
351                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
352                };
353                let result = self.bitwise(vs1, vs2);
354                let result = CpuStorage::I32(result);
355                Ok((result, l1.shape().clone()))
356            }
357            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
358        }
359    }
360
361    #[cfg(feature = "cuda")]
362    fn cuda_fwd(
363        &self,
364        s1: &CudaStorage,
365        l1: &Layout,
366        s2: &CudaStorage,
367        l2: &Layout,
368    ) -> Result<(CudaStorage, Shape)> {
369        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
370            return Err(Error::ShapeMismatchBinaryOp {
371                lhs: l1.shape().clone(),
372                rhs: l2.shape().clone(),
373                op: "bitwise-op",
374            });
375        }
376        if s1.dtype() != s2.dtype() {
377            return Err(Error::DTypeMismatchBinaryOp {
378                lhs: s1.dtype(),
379                rhs: s2.dtype(),
380                op: "bitwise-op",
381            });
382        }
383        if !l1.is_contiguous() {
384            candle_core::bail!("Input tensor s1 must be contiguous");
385        }
386        if !l2.is_contiguous() {
387            candle_core::bail!("Input tensor s2 must be contiguous");
388        }
389
390        let dev = s1.device().clone();
391        let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
392            DType::U8 => {
393                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
394                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
395                let elem_count = l1.shape().elem_count();
396                (
397                    d_in1 as *const std::ffi::c_void,
398                    d_in2 as *const std::ffi::c_void,
399                    d_in1_guard,
400                    d_in2_guard,
401                    elem_count,
402                )
403            }
404            DType::U32 => {
405                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
406                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
407                let elem_count = l1.shape().elem_count();
408                (
409                    d_in1 as *const std::ffi::c_void,
410                    d_in2 as *const std::ffi::c_void,
411                    d_in1_guard,
412                    d_in2_guard,
413                    elem_count,
414                )
415            }
416            DType::I64 => {
417                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
418                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
419                let elem_count = l1.shape().elem_count();
420                (
421                    d_in1 as *const std::ffi::c_void,
422                    d_in2 as *const std::ffi::c_void,
423                    d_in1_guard,
424                    d_in2_guard,
425                    elem_count,
426                )
427            }
428            DType::I32 => {
429                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
430                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
431                let elem_count = l1.shape().elem_count();
432                (
433                    d_in1 as *const std::ffi::c_void,
434                    d_in2 as *const std::ffi::c_void,
435                    d_in1_guard,
436                    d_in2_guard,
437                    elem_count,
438                )
439            }
440            DType::I16 => {
441                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
442                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
443                let elem_count = l1.shape().elem_count();
444                (
445                    d_in1 as *const std::ffi::c_void,
446                    d_in2 as *const std::ffi::c_void,
447                    d_in1_guard,
448                    d_in2_guard,
449                    elem_count,
450                )
451            }
452            other => {
453                return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
454            }
455        };
456        let dst = match s1.dtype() {
457            DType::U8 => {
458                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
459                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
460                unsafe {
461                    match self.op {
462                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
463                            d_in1_ptr,
464                            d_in2_ptr,
465                            d_out_ptr as *mut c_void,
466                            u32::try_from(elem_count)?,
467                        ),
468                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
469                            d_in1_ptr,
470                            d_in2_ptr,
471                            d_out_ptr as *mut c_void,
472                            u32::try_from(elem_count)?,
473                        ),
474                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
475                            d_in1_ptr,
476                            d_in2_ptr,
477                            d_out_ptr as *mut c_void,
478                            u32::try_from(elem_count)?,
479                        ),
480                    }
481                };
482                drop(d_out_guard);
483                CudaStorage::wrap_cuda_slice(d_out, dev)
484            }
485            DType::U32 => {
486                let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
487                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
488                unsafe {
489                    match self.op {
490                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
491                            d_in1_ptr,
492                            d_in2_ptr,
493                            d_out_ptr as *mut c_void,
494                            u32::try_from(elem_count)?,
495                        ),
496                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
497                            d_in1_ptr,
498                            d_in2_ptr,
499                            d_out_ptr as *mut c_void,
500                            u32::try_from(elem_count)?,
501                        ),
502                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
503                            d_in1_ptr,
504                            d_in2_ptr,
505                            d_out_ptr as *mut c_void,
506                            u32::try_from(elem_count)?,
507                        ),
508                    }
509                };
510                drop(d_out_guard);
511                CudaStorage::wrap_cuda_slice(d_out, dev)
512            }
513            DType::I64 => {
514                let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
515                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
516                unsafe {
517                    match self.op {
518                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
519                            d_in1_ptr,
520                            d_in2_ptr,
521                            d_out_ptr as *mut c_void,
522                            u32::try_from(elem_count)?,
523                        ),
524                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
525                            d_in1_ptr,
526                            d_in2_ptr,
527                            d_out_ptr as *mut c_void,
528                            u32::try_from(elem_count)?,
529                        ),
530                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
531                            d_in1_ptr,
532                            d_in2_ptr,
533                            d_out_ptr as *mut c_void,
534                            u32::try_from(elem_count)?,
535                        ),
536                    }
537                };
538                drop(d_out_guard);
539                CudaStorage::wrap_cuda_slice(d_out, dev)
540            }
541            DType::I32 => {
542                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
543                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
544                unsafe {
545                    match self.op {
546                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
547                            d_in1_ptr,
548                            d_in2_ptr,
549                            d_out_ptr as *mut c_void,
550                            u32::try_from(elem_count)?,
551                        ),
552                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
553                            d_in1_ptr,
554                            d_in2_ptr,
555                            d_out_ptr as *mut c_void,
556                            u32::try_from(elem_count)?,
557                        ),
558                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
559                            d_in1_ptr,
560                            d_in2_ptr,
561                            d_out_ptr as *mut c_void,
562                            u32::try_from(elem_count)?,
563                        ),
564                    }
565                };
566                drop(d_out_guard);
567                CudaStorage::wrap_cuda_slice(d_out, dev)
568            }
569            other => return Err(Error::UnsupportedDTypeForOp(other, "bitwise")),
570        };
571        Ok((dst, l1.shape().clone()))
572    }
573
574    #[cfg(feature = "metal")]
575    fn metal_fwd(
576        &self,
577        s1: &candle_core::MetalStorage,
578        l1: &Layout,
579        s2: &candle_core::MetalStorage,
580        l2: &Layout,
581    ) -> Result<(candle_core::MetalStorage, Shape)> {
582        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
583            return Err(Error::ShapeMismatchBinaryOp {
584                lhs: l1.shape().clone(),
585                rhs: l2.shape().clone(),
586                op: "bitwise-op",
587            });
588        }
589        if s1.dtype() != s2.dtype() {
590            return Err(Error::DTypeMismatchBinaryOp {
591                lhs: s1.dtype(),
592                rhs: s2.dtype(),
593                op: "bitwise-op",
594            });
595        }
596        if !l1.is_contiguous() {
597            candle_core::bail!("Input tensor s1 must be contiguous");
598        }
599        if !l2.is_contiguous() {
600            candle_core::bail!("Input tensor s2 must be contiguous");
601        }
602
603        let command_buffer = s1.device().command_buffer()?;
604        command_buffer.set_label("bitwise-op");
605
606        let device = s1.device();
607
608        let out_shape = l1.shape().clone();
609
610        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
611
612        match self.op {
613            BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
614                device.device(),
615                &command_buffer,
616                &crate::metal_kernels::Kernels::new(),
617                s1.dtype(),
618                s1.buffer(),
619                s2.buffer(),
620                l1.start_offset() * s1.dtype().size_in_bytes(),
621                l2.start_offset() * s2.dtype().size_in_bytes(),
622                out_shape.elem_count(),
623                &output,
624            )
625            .map_err(candle_core::Error::wrap)?,
626            BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
627                device.device(),
628                &command_buffer,
629                &crate::metal_kernels::Kernels::new(),
630                s1.dtype(),
631                s1.buffer(),
632                s2.buffer(),
633                l1.start_offset() * s1.dtype().size_in_bytes(),
634                l2.start_offset() * s2.dtype().size_in_bytes(),
635                out_shape.elem_count(),
636                &output,
637            )
638            .map_err(candle_core::Error::wrap)?,
639            BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
640                device.device(),
641                &command_buffer,
642                &crate::metal_kernels::Kernels::new(),
643                s1.dtype(),
644                s1.buffer(),
645                s2.buffer(),
646                l1.start_offset() * s1.dtype().size_in_bytes(),
647                l2.start_offset() * s2.dtype().size_in_bytes(),
648                out_shape.elem_count(),
649                &output,
650            )
651            .map_err(candle_core::Error::wrap)?,
652        }
653
654        let newstorage = candle_core::MetalStorage::new(
655            output,
656            device.clone(),
657            out_shape.elem_count(),
658            s1.dtype(),
659        );
660        Ok((newstorage, out_shape))
661    }
662}
663
664struct BitWiseUnary {
665    pub op: BitWiseUnaryOpEnum,
666}
667
668impl BitWiseUnary {
669    pub fn new(op: BitWiseUnaryOpEnum) -> Self {
670        Self { op }
671    }
672
673    fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
674        vs1.into_par_iter()
675            .map(|v1| match self.op {
676                BitWiseUnaryOpEnum::Not => !*v1,
677            })
678            .collect()
679    }
680}
681
682impl CustomOp1 for BitWiseUnary {
683    fn name(&self) -> &'static str {
684        "bitwise-unary"
685    }
686
687    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
688        if !l1.is_contiguous() {
689            candle_core::bail!("Input tensor s1 must be contiguous");
690        }
691
692        match s1 {
693            CpuStorage::U8(vs1) => {
694                let vs1 = match l1.contiguous_offsets() {
695                    Some((a, b)) => &vs1[a..b],
696                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
697                };
698                let result = self.bitwise(vs1);
699                let result = CpuStorage::U8(result);
700                Ok((result, l1.shape().clone()))
701            }
702            CpuStorage::U32(vs1) => {
703                let vs1 = match l1.contiguous_offsets() {
704                    Some((a, b)) => &vs1[a..b],
705                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
706                };
707                let result = self.bitwise(vs1);
708                let result = CpuStorage::U32(result);
709                Ok((result, l1.shape().clone()))
710            }
711            CpuStorage::I64(vs1) => {
712                let vs1 = match l1.contiguous_offsets() {
713                    Some((a, b)) => &vs1[a..b],
714                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
715                };
716                let result = self.bitwise(vs1);
717                let result = CpuStorage::I64(result);
718                Ok((result, l1.shape().clone()))
719            }
720            CpuStorage::I16(vs1) => {
721                let vs1 = match l1.contiguous_offsets() {
722                    Some((a, b)) => &vs1[a..b],
723                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
724                };
725                let result = self.bitwise(vs1);
726                let result = CpuStorage::I16(result);
727                Ok((result, l1.shape().clone()))
728            }
729            CpuStorage::I32(vs1) => {
730                let vs1 = match l1.contiguous_offsets() {
731                    Some((a, b)) => &vs1[a..b],
732                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
733                };
734                let result = self.bitwise(vs1);
735                let result = CpuStorage::I32(result);
736                Ok((result, l1.shape().clone()))
737            }
738            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
739        }
740    }
741
742    #[cfg(feature = "cuda")]
743    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
744        todo!()
745    }
746
747    #[cfg(feature = "metal")]
748    fn metal_fwd(
749        &self,
750        s1: &candle_core::MetalStorage,
751        l1: &Layout,
752    ) -> Result<(candle_core::MetalStorage, Shape)> {
753        if !l1.is_contiguous() {
754            candle_core::bail!("Input tensor s1 must be contiguous");
755        }
756
757        let command_buffer = s1.device().command_buffer()?;
758        command_buffer.set_label("bitwise-unary-op");
759
760        let device = s1.device();
761
762        let out_shape = l1.shape().clone();
763
764        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
765
766        match self.op {
767            BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
768                device.device(),
769                &command_buffer,
770                &crate::metal_kernels::Kernels::new(),
771                s1.dtype(),
772                s1.buffer(),
773                l1.start_offset() * s1.dtype().size_in_bytes(),
774                out_shape.elem_count(),
775                &output,
776            )
777            .map_err(candle_core::Error::wrap)?,
778        }
779
780        let newstorage = candle_core::MetalStorage::new(
781            output,
782            device.clone(),
783            out_shape.elem_count(),
784            s1.dtype(),
785        );
786        Ok((newstorage, out_shape))
787    }
788}
789
790#[allow(dead_code)]
791pub trait BitWiseOp {
792    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
793    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
794    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
795    fn bitwise_not(&self) -> Result<Tensor>;
796}
797
798impl BitWiseOp for Tensor {
799    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
800        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
801    }
802
803    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
804        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
805    }
806
807    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
808        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
809    }
810
811    fn bitwise_not(&self) -> Result<Tensor> {
812        self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
813    }
814}
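// Illustrative usage (editor's sketch, not part of the original file): with the
// `BitWiseOp` trait from this module in scope, the ops apply element-wise to two
// integer tensors of identical shape and dtype:
//
//     use candle_core::{Device, Tensor};
//     let a = Tensor::new(&[0b1100u8, 0b1010], &Device::Cpu)?;
//     let b = Tensor::new(&[0b1010u8, 0b0110], &Device::Cpu)?;
//     assert_eq!(a.bitwise_and(&b)?.to_vec1::<u8>()?, vec![0b1000u8, 0b0010]);
//     assert_eq!(a.bitwise_xor(&b)?.to_vec1::<u8>()?, vec![0b0110u8, 0b1100]);
//     assert_eq!(a.bitwise_not()?.to_vec1::<u8>()?, vec![0b1111_0011u8, 0b1111_0101]);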
815
816// ────────────────────────────── ArgSort / Sort ────────────────────────────────
817
818#[allow(unused)]
819/// Configuration for an **argsort** (returns indices) operation.
820struct ArgSort {
821    axis: usize,
822}
823
824#[allow(unused)]
825/// Configuration for a **sort** (returns re‑ordered values) operation.
826struct Sort {
827    axis: usize,
828}
829
830impl CustomOp1 for ArgSort {
831    fn name(&self) -> &'static str {
832        "argsort"
833    }
834
835    // -------- CPU ------------------------------------------------------------
836    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
837        candle_core::bail!("ArgSort is not implemented for the CPU backend");
838    }
839
840    // -------- CUDA -----------------------------------------------------------
841    #[cfg(feature = "cuda")]
842    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
843        candle_core::bail!("ArgSort is not implemented for the CUDA backend");
844    }
845
846    // -------- Metal ----------------------------------------------------------
847    #[cfg(feature = "metal")]
848    fn metal_fwd(
849        &self,
850        s1: &candle_core::MetalStorage,
851        l1: &Layout,
852    ) -> Result<(candle_core::MetalStorage, Shape)> {
853        // Require contiguous input (same as other metal ops in this file)
854        if !l1.is_contiguous() {
855            candle_core::bail!("Input tensor s1 must be contiguous");
856        }
857
858        // Create a command‑buffer and label it for easy debugging in Xcode’s GPU frame‑capture
859        let command_buffer = s1.device().command_buffer()?;
860        command_buffer.set_label("argsort");
861
862        let device = s1.device();
863        let out_shape = l1.shape().clone();
864        let elem_count = out_shape.elem_count();
865
866        // Output buffer holds the sorted indices → always `U32`
867        let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
868
869        // ------------------------------------------------------------------
870        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
871        // ------------------------------------------------------------------
872        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
873
874        let dims = l1.dims();
875        let size_sorted_axis = dims[self.axis];
876        let n_rows = l1.shape().elem_count() / size_sorted_axis;
877
878        // Replicate the kernel’s internal block sizing to derive `n_blocks`
879        let tn = 4usize;
880        let mut bn = match size_sorted_axis.div_ceil(tn) {
881            v if v > 256 => 512,
882            v if v > 128 => 256,
883            v if v > 64 => 128,
884            v if v > 32 => 64,
885            _ => 32,
886        };
887        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
888            bn = 256;
889        }
890        let n_per_block = bn * tn;
891        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
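        // Worked example (editor's note): with tn = 4, a sorted-axis length of 1000
        // gives div_ceil(1000, 4) = 250 -> bn = 256, n_per_block = 1024, n_blocks = 1
        // (single-block case); a length of 10_000 with a 4-byte dtype gives
        // div_ceil = 2500 -> bn = 512, n_per_block = 2048, n_blocks = 5 (multi-block case).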
892
893        // Borrow the buffers for this launch
894        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
895
896        // ------------------------------------------------------------------
897        // Build the unified SortArgs payload
898        // ------------------------------------------------------------------
899        let sort_args = crate::metal_kernels::SortArgs {
900            axis: self.axis,
901            shape: l1.dims(),
902            strides: l1.stride(),
903            out_shape: l1.dims(), // same as input for argsort
904            out_strides: l1.stride(),
905            in_contiguous: l1.is_contiguous(),
906            in_ty: s1.dtype(),
907            out_ty: candle_core::DType::U32,
908            src: s1.buffer(),
909            src_offset: l1.start_offset(), // element offset
910            dst: &output,
911            bn,
912            tn,
913            n_blocks,
914        };
915
916        // Launch the Metal kernel via the new API
917        crate::metal_kernels::call_argsort(
918            device.device(), // &metal::Device
919            &command_buffer, // impl EncoderProvider
920            &crate::metal_kernels::Kernels::new(),
921            &sort_args,
922            &scratch,
923        )
924        .map_err(candle_core::Error::wrap)?;
925
926        // Wrap and return as a new MetalStorage
927        let newstorage = candle_core::MetalStorage::new(
928            output,
929            device.clone(),
930            elem_count,
931            candle_core::DType::U32,
932        );
933        Ok((newstorage, out_shape))
934    }
935}
936
937impl CustomOp1 for Sort {
938    fn name(&self) -> &'static str {
939        "sort"
940    }
941
942    // -------- CPU ------------------------------------------------------------
943    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
944        candle_core::bail!("Sort is not implemented for the CPU backend");
945    }
946
947    // -------- CUDA -----------------------------------------------------------
948    #[cfg(feature = "cuda")]
949    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
950        candle_core::bail!("Sort is not implemented for the CUDA backend");
951    }
952
953    // -------- Metal ----------------------------------------------------------
954    #[cfg(feature = "metal")]
955    fn metal_fwd(
956        &self,
957        s1: &candle_core::MetalStorage,
958        l1: &Layout,
959    ) -> Result<(candle_core::MetalStorage, Shape)> {
960        // Require contiguous input (same as other metal ops in this file)
961        if !l1.is_contiguous() {
962            candle_core::bail!("Input tensor s1 must be contiguous");
963        }
964
965        // Create a command‑buffer and label it for easy debugging in Xcode’s GPU frame‑capture
966        let command_buffer = s1.device().command_buffer()?;
967        command_buffer.set_label("sort");
968
969        let device = s1.device();
970        let out_shape = l1.shape().clone();
971        let elem_count = out_shape.elem_count();
972
973        // Output buffer keeps the same dtype as the input (these are the reordered values)
974        let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
975
976        // ------------------------------------------------------------------
977        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
978        // ------------------------------------------------------------------
979        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
980
981        let dims = l1.dims();
982        let size_sorted_axis = dims[self.axis];
983        let n_rows = l1.shape().elem_count() / size_sorted_axis;
984
985        // Replicate the kernel’s internal block sizing to derive `n_blocks`
986        let tn = 4usize;
987        let mut bn = match size_sorted_axis.div_ceil(tn) {
988            v if v > 256 => 512,
989            v if v > 128 => 256,
990            v if v > 64 => 128,
991            v if v > 32 => 64,
992            _ => 32,
993        };
994        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
995            bn = 256;
996        }
997        let n_per_block = bn * tn;
998        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
999
1000        // Borrow the buffers for this launch
1001        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1002
1003        // ------------------------------------------------------------------
1004        // Build the unified SortArgs payload
1005        // ------------------------------------------------------------------
1006        let sort_args = crate::metal_kernels::SortArgs {
1007            axis: self.axis,
1008            shape: l1.dims(),
1009            strides: l1.stride(),
1010            out_shape: l1.dims(), // same shape for value sort
1011            out_strides: l1.stride(),
1012            in_contiguous: l1.is_contiguous(),
1013            in_ty: s1.dtype(),
1014            out_ty: s1.dtype(),
1015            src: s1.buffer(),
1016            src_offset: l1.start_offset(), // element offset
1017            dst: &output,
1018            bn,
1019            tn,
1020            n_blocks,
1021        };
1022
1023        // Launch the Metal kernel via the new API
1024        crate::metal_kernels::call_sort(
1025            device.device(), // &metal::Device
1026            &command_buffer, // impl EncoderProvider
1027            &crate::metal_kernels::Kernels::new(),
1028            &sort_args,
1029            &scratch,
1030        )
1031        .map_err(candle_core::Error::wrap)?;
1032
1033        // Wrap and return as a new MetalStorage
1034        let newstorage =
1035            candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1036        Ok((newstorage, out_shape))
1037    }
1038}
1039
1040/// Extension trait adding `argsort` / `sort` convenience calls on `Tensor`.
1041pub trait SortOp {
1042    /// Returns the indices that would (ascending) sort the tensor along `axis`.
1043    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1044    /// Returns the tensor's values (ascending) sorted along `axis`.
1045    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1046}
1047
1048impl SortOp for Tensor {
1049    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1050        if self.device().is_cpu() || self.device().is_cuda() {
1051            return self.arg_sort_last_dim(true);
1052        }
1053        self.apply_op1_no_bwd(&ArgSort {
1054            axis: axis.to_index(self.shape(), "argsort")?,
1055        })
1056    }
1057
1058    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1059        if self.device().is_cpu() || self.device().is_cuda() {
1060            return Ok(self.sort_last_dim(true)?.0);
1061        }
1062        self.apply_op1_no_bwd(&Sort {
1063            axis: axis.to_index(self.shape(), "sort")?,
1064        })
1065    }
1066}
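// Illustrative usage (editor's sketch, not part of the original file). On CPU and CUDA
// these helpers delegate to candle's built-in last-dim sort, so sort along the last axis:
//
//     use candle_core::{Device, Tensor};
//     let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
//     let idx = t.fast_argsort_asc(0)?;  // U32 indices: [1, 2, 0]
//     let vals = t.fast_sort_asc(0)?;    // values:      [1., 2., 3.]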
1067
1068struct NonZero;
1069
1070impl NonZero {
1071    // Sequential version
1072    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1073        let n = layout.dims().len();
1074        let mut result = Vec::new();
1075        let mut indices = vec![0u32; n];
1076        for (i, v) in vs.iter().enumerate() {
1077            if !v.is_zero() {
1078                let mut idx = i;
1079                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1080                    let d = idx % dim;
1081                    indices[dim_index] = u32::try_from(d).unwrap();
1082                    idx /= dim;
1083                }
1084                result.extend_from_slice(&indices);
1085            }
1086        }
1087        result
1088    }
1089}
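// Worked example (editor's note): for a contiguous 2x3 tensor [[0, 7, 0], [0, 0, 5]],
// the nonzero flat positions are 1 and 5; unravelling them against dims (2, 3) yields
// the multi-dim indices [0, 1] and [1, 2], so `result` is [0, 1, 1, 2], later viewed as
// a (num_nonzero, n_dims) = (2, 2) index matrix by `cpu_fwd`.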
1090
1091#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1092mod cuda_ops_cccl2 {
1093    use super::*;
1094
1095    pub(super) fn count_nonzero_cuda(
1096        dtype: candle_core::DType,
1097        d_in: *const c_void,
1098        n: u32,
1099        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1100    ) -> u32 {
1101        unsafe {
1102            match dtype {
1103                candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1104                candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1105                candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1106                candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1107                candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1108                candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1109                candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1110                candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1111                candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1112                _ => unreachable!(),
1113            }
1114        }
1115    }
1116
1117    #[allow(clippy::too_many_arguments)]
1118    pub(super) fn nonzero_cuda(
1119        dtype: candle_core::DType,
1120        d_in: *const c_void,
1121        n: u32,
1122        num_nonzero: u32,
1123        dims: *const c_void,
1124        num_dims: u32,
1125        d_out: *mut c_void,
1126        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1127    ) {
1128        unsafe {
1129            match dtype {
1130                candle_core::DType::U8 => {
1131                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1132                }
1133                candle_core::DType::U32 => {
1134                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1135                }
1136                candle_core::DType::I64 => {
1137                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1138                }
1139                candle_core::DType::I32 => {
1140                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1141                }
1142                candle_core::DType::I16 => {
1143                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1144                }
1145                candle_core::DType::BF16 => {
1146                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1147                }
1148                candle_core::DType::F16 => {
1149                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1150                }
1151                candle_core::DType::F32 => {
1152                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1153                }
1154                candle_core::DType::F64 => {
1155                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1156                }
1157                _ => unreachable!(),
1158            }
1159        }
1160    }
1161}
1162
1163#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1164mod cuda_ops_cccl3 {
1165    use super::*;
1166
1167    pub(super) fn count_nonzero_cuda(
1168        dtype: candle_core::DType,
1169        d_in: *const c_void,
1170        n: u32,
1171        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1172    ) -> u32 {
1173        unsafe {
1174            match dtype {
1175                candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1176                candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1177                candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1178                candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1179                candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1180                candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1181                candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1182                candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1183                candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1184                _ => unreachable!(),
1185            }
1186        }
1187    }
1188
1189    #[allow(clippy::too_many_arguments)]
1190    pub(super) fn nonzero_cuda(
1191        dtype: candle_core::DType,
1192        d_in: *const c_void,
1193        n: u32,
1194        num_nonzero: u32,
1195        dims: *const c_void,
1196        num_dims: u32,
1197        d_out: *mut c_void,
1198        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1199    ) {
1200        unsafe {
1201            match dtype {
1202                candle_core::DType::U8 => {
1203                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1204                }
1205                candle_core::DType::U32 => {
1206                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1207                }
1208                candle_core::DType::I64 => {
1209                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1210                }
1211                candle_core::DType::I32 => {
1212                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1213                }
1214                candle_core::DType::I16 => {
1215                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1216                }
1217                candle_core::DType::BF16 => {
1218                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1219                }
1220                candle_core::DType::F16 => {
1221                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1222                }
1223                candle_core::DType::F32 => {
1224                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1225                }
1226                candle_core::DType::F64 => {
1227                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1228                }
1229                _ => unreachable!(),
1230            }
1231        }
1232    }
1233}
1234
1235#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1236use cuda_ops_cccl2::{count_nonzero_cuda, nonzero_cuda};
1237#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1238use cuda_ops_cccl3::{count_nonzero_cuda, nonzero_cuda};
1239
1240impl CustomOp1 for NonZero {
1241    fn name(&self) -> &'static str {
1242        "nonzero"
1243    }
1244
1245    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1246        if !layout.is_contiguous() {
1247            return Err(Error::RequiresContiguous { op: "nonzero" });
1248        }
1249        let result = match storage {
1250            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1251            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1252            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1253            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1254            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1255            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1256            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1257            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1258            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1259            _ => unreachable!(),
1260        };
1261        let index_len = layout.dims().len();
1262        let result_len = result.len() / index_len;
1263        let result = CpuStorage::U32(result);
1264        let shape = Shape::from_dims(&[result_len, index_len]);
1265        Ok((result, shape))
1266    }
1267
1268    #[cfg(feature = "cuda")]
1269    fn cuda_fwd(
1270        &self,
1271        storage: &candle_core::CudaStorage,
1272        layout: &Layout,
1273    ) -> Result<(candle_core::CudaStorage, Shape)> {
1274        if !layout.is_contiguous() {
1275            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1276        }
1277        let dev = storage.device().clone();
1278        let (d_in, _d_in_guard) = match storage.dtype() {
1279            candle_core::DType::U8 => {
1280                let slice = storage.as_cuda_slice::<u8>()?;
1281                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1282                (d_in as *const std::ffi::c_void, d_in_guard)
1283            }
1284            candle_core::DType::U32 => {
1285                let slice = storage.as_cuda_slice::<u32>()?;
1286                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1287                (d_in as *const std::ffi::c_void, d_in_guard)
1288            }
1289            candle_core::DType::I32 => {
1290                let slice = storage.as_cuda_slice::<i32>()?;
1291                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1292                (d_in as *const std::ffi::c_void, d_in_guard)
1293            }
1294            candle_core::DType::I16 => {
1295                let slice = storage.as_cuda_slice::<i16>()?;
1296                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1297                (d_in as *const std::ffi::c_void, d_in_guard)
1298            }
1299            candle_core::DType::I64 => {
1300                let slice = storage.as_cuda_slice::<i64>()?;
1301                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1302                (d_in as *const std::ffi::c_void, d_in_guard)
1303            }
1304            candle_core::DType::BF16 => {
1305                let slice = storage.as_cuda_slice::<half::bf16>()?;
1306                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1307                (d_in as *const std::ffi::c_void, d_in_guard)
1308            }
1309            candle_core::DType::F16 => {
1310                let slice = storage.as_cuda_slice::<half::f16>()?;
1311                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1312                (d_in as *const std::ffi::c_void, d_in_guard)
1313            }
1314            candle_core::DType::F32 => {
1315                let slice = storage.as_cuda_slice::<f32>()?;
1316                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1317                (d_in as *const std::ffi::c_void, d_in_guard)
1318            }
1319            candle_core::DType::F64 => {
1320                let slice = storage.as_cuda_slice::<f64>()?;
1321                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1322                (d_in as *const std::ffi::c_void, d_in_guard)
1323            }
1324            _ => unreachable!(),
1325        };
1326        let n = layout.shape().elem_count();
1327
1328        let num_nonzero = count_nonzero_cuda(
1329            storage.dtype(),
1330            d_in,
1331            u32::try_from(n)?,
1332            dev.cuda_stream().cu_stream(),
1333        );
1334        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1335            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1336        if num_nonzero != 0 {
1337            let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1338            let dims = layout
1339                .dims()
1340                .iter()
1341                .map(|&x| u32::try_from(x).unwrap())
1342                .collect::<Vec<u32>>();
1343            let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1344            dev.memcpy_htod(&dims, &mut d_dims)?;
1345            let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1346            nonzero_cuda(
1347                storage.dtype(),
1348                d_in,
1349                u32::try_from(n)?,
1350                num_nonzero,
1351                d_dims_ptr as *const c_void,
1352                u32::try_from(layout.dims().len())?,
1353                d_out as *mut c_void,
1354                dev.cuda_stream().cu_stream(),
1355            );
1356        }
1357        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1358        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1359        Ok((dst, shape))
1360    }
1361}
1362
1363pub trait NonZeroOp {
1364    fn nonzero(&self) -> Result<Tensor>;
1365}
1366
1367impl NonZeroOp for Tensor {
1368    #[cfg(feature = "metal")]
1369    fn nonzero(&self) -> Result<Tensor> {
1370        if !self.is_contiguous() {
1371            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1372        }
1373        let original_device = self.device();
1374        self.to_device(&candle_core::Device::Cpu)?
1375            .apply_op1_no_bwd(&NonZero)?
1376            .to_device(original_device)
1377    }
1378
1379    #[cfg(not(feature = "metal"))]
1380    fn nonzero(&self) -> Result<Tensor> {
1381        if !self.is_contiguous() {
1382            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1383        }
1384        self.apply_op1_no_bwd(&NonZero)
1385    }
1386}
1387
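/// Custom op computing a cumulative sum along `axis`.
///
/// `inclusive` decides whether each position's own value is included in its running sum;
/// `reverse` scans from the end of the axis towards the start.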
1388struct CumSum {
1389    inclusive: bool,
1390    reverse: bool,
1391    axis: usize,
1392}
1393
1394impl CustomOp1 for CumSum {
1395    fn name(&self) -> &'static str {
1396        "cumsum"
1397    }
1398
1399    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1400        use std::ops::Add;
1401        if !l1.is_contiguous() {
1402            candle_core::bail!("Input tensor s1 must be contiguous");
1403        }
1404        let dims = l1.dims();
1405        let axis = self.axis;
1406        let axis_len = dims[axis];
1407        let (start, end) = l1
1408            .contiguous_offsets()
1409            .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1410
1411        // helper to execute scan for a slice of T
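        // For example, an exclusive forward scan over [1, 2, 3, 4] yields [0, 1, 3, 6],
        // while the inclusive variant yields [1, 3, 6, 10] (see the CPU tests below).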
1412        macro_rules! scan_block {
1413            ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1414                let vs: &[$ty] = $vt;
1415                let input = &vs[start..end];
1416                let count = input.len() / axis_len;
1417                let mut result = Vec::<$ty>::with_capacity(input.len());
1418                if !self.reverse {
1419                    if self.inclusive {
1420                        for block in 0..count {
1421                            let base = block * axis_len;
1422                            let mut sum = input[base];
1423                            result.push(sum);
1424                            for j in 1..axis_len {
1425                                sum = sum.$add(input[base + j]);
1426                                result.push(sum);
1427                            }
1428                        }
1429                    } else {
1430                        let init: $ty = $init;
1431                        for block in 0..count {
1432                            let base = block * axis_len;
1433                            let mut sum = init;
1434                            for j in 0..axis_len {
1435                                result.push(sum);
1436                                sum = sum.$add(input[base + j]);
1437                            }
1438                        }
1439                    }
1440                } else {
1441                    if self.inclusive {
1442                        for block in 0..count {
1443                            let base = block * axis_len;
1444                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1445                            let mut sum = input[base + axis_len - 1];
1446                            temp.push(sum);
1447                            for k in 1..axis_len {
1448                                let idx = axis_len - 1 - k;
1449                                sum = sum.$add(input[base + idx]);
1450                                temp.push(sum);
1451                            }
1452                            temp.reverse();
1453                            result.extend(temp);
1454                        }
1455                    } else {
1456                        let init: $ty = $init;
1457                        for block in 0..count {
1458                            let base = block * axis_len;
1459                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1460                            let mut sum = init;
1461                            for k in 0..axis_len {
1462                                let idx = axis_len - 1 - k;
1463                                temp.push(sum);
1464                                sum = sum.$add(input[base + idx]);
1465                            }
1466                            temp.reverse();
1467                            result.extend(temp);
1468                        }
1469                    }
1470                }
1471                result
1472            }};
1473        }
1474        match s1 {
1475            CpuStorage::U8(vs) => {
1476                let result = scan_block!(vs, u8, wrapping_add, 0u8);
1477                Ok((CpuStorage::U8(result), l1.shape().clone()))
1478            }
1479            CpuStorage::I16(vs) => {
1480                let result = scan_block!(vs, i16, add, 0i16);
1481                Ok((CpuStorage::I16(result), l1.shape().clone()))
1482            }
1483            CpuStorage::U32(vs) => {
1484                let result = scan_block!(vs, u32, wrapping_add, 0u32);
1485                Ok((CpuStorage::U32(result), l1.shape().clone()))
1486            }
1487            CpuStorage::I32(vs) => {
1488                let result = scan_block!(vs, i32, add, 0i32);
1489                Ok((CpuStorage::I32(result), l1.shape().clone()))
1490            }
1491            CpuStorage::I64(vs) => {
1492                let result = scan_block!(vs, i64, add, 0i64);
1493                Ok((CpuStorage::I64(result), l1.shape().clone()))
1494            }
1495            CpuStorage::F32(vs) => {
1496                let result = scan_block!(vs, f32, add, 0.0f32);
1497                Ok((CpuStorage::F32(result), l1.shape().clone()))
1498            }
1499            CpuStorage::F64(vs) => {
1500                let result = scan_block!(vs, f64, add, 0.0f64);
1501                Ok((CpuStorage::F64(result), l1.shape().clone()))
1502            }
1503            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1504        }
1505    }
1506
1507    #[cfg(feature = "cuda")]
1508    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
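        // No CUDA kernel for cumsum yet; reaching this path panics via `todo!()`.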
1509        todo!()
1510    }
1511
1512    #[cfg(feature = "metal")]
1513    fn metal_fwd(
1514        &self,
1515        s1: &candle_core::MetalStorage,
1516        l1: &Layout,
1517    ) -> Result<(candle_core::MetalStorage, Shape)> {
1518        use crate::metal_kernels::ScanType;
1519
1520        let command_buffer = s1.device().command_buffer()?;
1521        command_buffer.set_label("cumsum");
1522
1523        let device = s1.device();
1524
1525        let out_shape = l1.shape().clone();
1526
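        // Allocate the output buffer and dispatch the Metal scan kernel with `Sum` as the scan op.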
1527        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1528
1529        crate::metal_kernels::call_scan(
1530            device.device(),
1531            &command_buffer,
1532            &crate::metal_kernels::Kernels::new(),
1533            s1.dtype(),
1534            ScanType::Sum,
1535            s1.buffer(),
1536            l1.start_offset() * s1.dtype().size_in_bytes(),
1537            self.axis,
1538            l1.dims(),
1539            l1.stride(),
1540            self.reverse,
1541            self.inclusive,
1542            &output,
1543        )
1544        .map_err(candle_core::Error::wrap)?;
1545
1546        let newstorage = candle_core::MetalStorage::new(
1547            output,
1548            device.clone(),
1549            out_shape.elem_count(),
1550            s1.dtype(),
1551        );
1552        Ok((newstorage, out_shape))
1553    }
1554}
1555
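/// Cumulative-sum helpers built on the `CumSum` custom op.
///
/// A minimal usage sketch (values mirror the CPU tests below); marked `ignore` so it is
/// not run as a doctest:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu)?;
/// assert_eq!(a.fast_cumsum(0)?.to_vec1::<i64>()?, [0, 1, 3, 6]);                      // exclusive
/// assert_eq!(a.fast_cumsum_config(0, true, false)?.to_vec1::<i64>()?, [1, 3, 6, 10]); // inclusive
/// ```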
1556#[allow(dead_code)]
1557pub trait CumSumOp {
1558    /// Exclusive forward scan along `axis`; shorthand for `fast_cumsum_config(axis, false, false)`.
1559    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1560
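    /// Cumulative sum along `axis` with explicit `inclusive` / `reverse` behaviour.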
1561    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1562        -> Result<Tensor>;
1563}
1564
1565impl CumSumOp for Tensor {
1566    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1567        self.fast_cumsum_config(axis, false, false)
1568    }
1569
1570    fn fast_cumsum_config<D: Dim>(
1571        &self,
1572        axis: D,
1573        inclusive: bool,
1574        reverse: bool,
1575    ) -> Result<Tensor> {
1576        self.apply_op1_no_bwd(&CumSum {
1577            inclusive,
1578            reverse,
1579            axis: axis.to_index(self.shape(), "cumsum")?,
1580        })
1581    }
1582}
1583
1584#[cfg(test)] mod tests {
1585    #[test]
1586    fn test_cumsum_exclusive_forward_cpu() {
1587        use crate::utils::ops::CumSumOp;
1588        use candle_core::Tensor;
1589        let device = candle_core::Device::Cpu;
1590        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1591        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1592        assert_eq!(b, [0, 1, 3, 6]);
1593    }
1594
1595    #[test]
1596    fn test_cumsum_inclusive_forward_cpu() {
1597        use crate::utils::ops::CumSumOp;
1598        use candle_core::Tensor;
1599        let device = candle_core::Device::Cpu;
1600        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1601        let b = a
1602            .fast_cumsum_config(0, true, false)
1603            .unwrap()
1604            .to_vec1::<i64>()
1605            .unwrap();
1606        assert_eq!(b, [1, 3, 6, 10]);
1607    }
1608
1609    #[test]
1610    fn test_cumsum_exclusive_reverse_cpu() {
1611        use crate::utils::ops::CumSumOp;
1612        use candle_core::Tensor;
1613        let device = candle_core::Device::Cpu;
1614        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1615        let b = a
1616            .fast_cumsum_config(0, false, true)
1617            .unwrap()
1618            .to_vec1::<i64>()
1619            .unwrap();
1620        assert_eq!(b, [9, 7, 4, 0]);
1621    }
1622
1623    #[test]
1624    fn test_cumsum_inclusive_reverse_cpu() {
1625        use crate::utils::ops::CumSumOp;
1626        use candle_core::Tensor;
1627        let device = candle_core::Device::Cpu;
1628        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1629        let b = a
1630            .fast_cumsum_config(0, true, true)
1631            .unwrap()
1632            .to_vec1::<i64>()
1633            .unwrap();
1634        assert_eq!(b, [10, 9, 7, 4]);
1635    }
1636
1637    #[cfg(feature = "metal")]
1638    #[test]
1639    fn test_cumsum_exclusive_forward_metal() {
1640        use crate::utils::ops::CumSumOp;
1641        use candle_core::Tensor;
1642        let device = candle_core::Device::new_metal(0).unwrap();
1643        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1644        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1645        assert_eq!(b, [0, 1, 3, 6]);
1646    }
1647
1648    #[cfg(feature = "metal")]
1649    #[test]
1650    fn test_cumsum_inclusive_forward_metal() {
1651        use crate::utils::ops::CumSumOp;
1652        use candle_core::Tensor;
1653        let device = candle_core::Device::new_metal(0).unwrap();
1654        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1655        let b = a
1656            .fast_cumsum_config(0, true, false)
1657            .unwrap()
1658            .to_vec1::<i64>()
1659            .unwrap();
1660        assert_eq!(b, [1, 3, 6, 10]);
1661    }
1662
1663    #[cfg(feature = "metal")]
1664    #[test]
1665    fn test_cumsum_exclusive_reverse_metal() {
1666        use crate::utils::ops::CumSumOp;
1667        use candle_core::Tensor;
1668        let device = candle_core::Device::new_metal(0).unwrap();
1669        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1670        let b = a
1671            .fast_cumsum_config(0, false, true)
1672            .unwrap()
1673            .to_vec1::<i64>()
1674            .unwrap();
1675        assert_eq!(b, [9, 7, 4, 0]);
1676    }
1677
1678    #[cfg(feature = "metal")]
1679    #[test]
1680    fn test_cumsum_inclusive_reverse_metal() {
1681        use crate::utils::ops::CumSumOp;
1682        use candle_core::Tensor;
1683        let device = candle_core::Device::new_metal(0).unwrap();
1684        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1685        let b = a
1686            .fast_cumsum_config(0, true, true)
1687            .unwrap()
1688            .to_vec1::<i64>()
1689            .unwrap();
1690        assert_eq!(b, [10, 9, 7, 4]);
1691    }
1692
1693    #[test]
1694    fn test_nonzero_cpu() {
1695        use crate::utils::ops::NonZeroOp;
1696        use candle_core::Tensor;
1697        let device = candle_core::Device::Cpu;
1698        let a = Tensor::from_vec(
1699            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1700            &[2, 4],
1701            &device,
1702        )
1703        .unwrap();
1704        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1705        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1706    }
1707
1708    #[cfg(feature = "cuda")]
1709    #[test]
1710    fn test_nonzero_cuda() {
1711        use crate::utils::ops::NonZeroOp;
1712        use candle_core::Tensor;
1713        let device = candle_core::Device::new_cuda(0).unwrap();
1714        let a = Tensor::from_vec(
1715            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1716            &[2, 4],
1717            &device,
1718        )
1719        .unwrap();
1720        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1721        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1722    }
1723
1724    #[test]
1725    fn test_bitwise_and_cpu() {
1726        use crate::utils::ops::BitWiseOp;
1727        use candle_core::Tensor;
1728        let device = candle_core::Device::Cpu;
1729        let a =
1730            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1731        let b =
1732            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1733        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1734        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
1735    }
1736
1737    #[cfg(feature = "cuda")]
1738    #[test]
1739    fn test_bitwise_and_cuda() {
1740        use crate::utils::ops::BitWiseOp;
1741        use candle_core::Tensor;
1742        let device = candle_core::Device::new_cuda(0).unwrap();
1743        let a =
1744            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1745        let b =
1746            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
1747        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1748        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
1749    }
1750
1751    #[test]
1752    fn test_bitwise_or_cpu() {
1753        use crate::utils::ops::BitWiseOp;
1754        use candle_core::Tensor;
1755        let device = candle_core::Device::Cpu;
1756        let a =
1757            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1758        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1759        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1760        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1761    }
1762
1763    #[cfg(feature = "cuda")]
1764    #[test]
1765    fn test_bitwise_or_cuda() {
1766        use crate::utils::ops::BitWiseOp;
1767        use candle_core::Tensor;
1768        let device = candle_core::Device::new_cuda(0).unwrap();
1769        let a =
1770            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1771        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1772        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1773        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1774    }
1775
1776    #[test]
1777    fn test_bitwise_xor_cpu() {
1778        use crate::utils::ops::BitWiseOp;
1779        use candle_core::Tensor;
1780        let device = candle_core::Device::Cpu;
1781        let a =
1782            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1783        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1784        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1785        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1786    }
1787
1788    #[cfg(feature = "cuda")]
1789    #[test]
1790    fn test_bitwise_xor_cuda() {
1791        use crate::utils::ops::BitWiseOp;
1792        use candle_core::Tensor;
1793        let device = candle_core::Device::new_cuda(0).unwrap();
1794        let a =
1795            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1796        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1797        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1798        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1799    }
1800
1801    #[test]
1802    fn test_nonzero_and() {
1803        use crate::utils::ops::{BitWiseOp, NonZeroOp};
1804        use candle_core::{Device, Tensor};
1805
1806        let input1 = Tensor::from_vec(
1807            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
1808            (10,),
1809            &Device::Cpu,
1810        )
1811        .unwrap();
1812        let input2 = Tensor::from_vec(
1813            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
1814            (10,),
1815            &Device::Cpu,
1816        )
1817        .unwrap();
1818        let input = Tensor::stack(&[input1, input2], 0).unwrap();
1819
1820        let lt = input.lt(0.0).unwrap();
1821        let gt = input.gt(-10.0).unwrap();
1822        let res = lt
1823            .bitwise_and(&gt)
1824            .unwrap()
1825            .nonzero()
1826            .unwrap()
1827            .to_vec2::<u32>()
1828            .unwrap();
1829
1830        assert_eq!(
1831            res,
1832            [
1833                [0, 3],
1834                [0, 4],
1835                [0, 5],
1836                [0, 6],
1837                [1, 0],
1838                [1, 3],
1839                [1, 5],
1840                [1, 6]
1841            ]
1842        );
1843    }
1844
1845    #[cfg(feature = "cuda")]
1846    #[test]
1847    fn nonzero_and_cuda() {
1848        use crate::utils::ops::{BitWiseOp, NonZeroOp};
1849        use candle_core::{Device, Tensor};
1850
1851        let device = Device::new_cuda(0).unwrap();
1852        let input1 =
1853            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1854        let input2 =
1855            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1856        let input = Tensor::stack(&[input1, input2], 0).unwrap();
1857
1858        let lt = input.lt(0.0).unwrap();
1859        let gt = input.gt(-10.0).unwrap();
1860        let res = lt
1861            .bitwise_and(&gt)
1862            .unwrap()
1863            .nonzero()
1864            .unwrap()
1865            .to_vec2::<u32>()
1866            .unwrap();
1867
1868        assert_eq!(
1869            res,
1870            [
1871                [0, 3],
1872                [0, 4],
1873                [0, 5],
1874                [0, 6],
1875                [1, 0],
1876                [1, 3],
1877                [1, 5],
1878                [1, 6]
1879            ]
1880        );
1881    }
1882
1883    #[test]
1884    fn test_bitpack_8bit_cpu() {
1885        use crate::HqqBits;
1886        use candle_core::{Device, Tensor};
1887        let bits = HqqBits::Eight;
1888        let device = Device::Cpu;
1889        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1890        let c = bits.bitpack_type()(wq.clone())
1891            .unwrap()
1892            .to_vec2::<u8>()
1893            .unwrap();
1894        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1895    }
1896
1897    #[cfg(feature = "cuda")]
1898    #[test]
1899    fn test_bitpack_8bit_cuda() {
1900        use crate::HqqBits;
1901        use candle_core::DType;
1902        use candle_core::{Device, Tensor};
1903        let bits = HqqBits::Eight;
1904        let device = Device::new_cuda(0).unwrap();
1905        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1906        let c = bits.bitpack_type()(wq.clone())
1907            .unwrap()
1908            .to_dtype(DType::U8)
1909            .unwrap()
1910            .to_vec2::<u8>()
1911            .unwrap();
1912        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1913    }
1914
1915    #[cfg(feature = "metal")]
1916    #[test]
1917    fn test_bitpack_8bit_metal() {
1918        use crate::HqqBits;
1919        use candle_core::{Device, Tensor};
1920        let bits = HqqBits::Eight;
1921        let device = Device::new_metal(0).unwrap();
1922        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1923        let c = bits.bitpack_type()(wq.clone())
1924            .unwrap()
1925            .to_vec2::<u8>()
1926            .unwrap();
1927        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1928    }
1929
1930    #[test]
1931    fn test_bitpack_4bit() {
1932        use crate::HqqBits;
1933        use candle_core::{Device, Tensor};
1934        let bits = HqqBits::Four;
1935        let device = Device::Cpu;
1936        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1937        let c = bits.bitpack_type()(wq.clone())
1938            .unwrap()
1939            .to_vec2::<u8>()
1940            .unwrap();
1941        assert_eq!(c, [[19, 36]]);
1942    }
1943
1944    #[cfg(feature = "cuda")]
1945    #[test]
1946    fn test_bitpack_4bit_cuda() {
1947        use crate::HqqBits;
1948        use candle_core::{Device, Tensor};
1949        let bits = HqqBits::Four;
1950        let device = Device::new_cuda(0).unwrap();
1951        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1952        let c = bits.bitpack_type()(wq.clone())
1953            .unwrap()
1954            .to_vec2::<u8>()
1955            .unwrap();
1956        assert_eq!(c, [[19, 36]]);
1957    }
1958
1959    #[cfg(feature = "metal")]
1960    #[test]
1961    fn test_bitpack_4bit_metal() {
1962        use crate::HqqBits;
1963        use candle_core::{Device, Tensor};
1964        let bits = HqqBits::Four;
1965        let device = Device::new_metal(0).unwrap();
1966        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1967        let c = bits.bitpack_type()(wq.clone())
1968            .unwrap()
1969            .to_vec2::<u8>()
1970            .unwrap();
1971        assert_eq!(c, [[19, 36]]);
1972    }
1973    // ─────────────────────────────── Sort / ArgSort ────────────────────────────────
1974    #[cfg(feature = "metal")]
1975    #[test]
1976    fn test_sort_and_argsort_vector_metal() {
1977        use crate::utils::ops::SortOp;
1978        use candle_core::Tensor;
1979
1980        let device = candle_core::Device::new_metal(0).unwrap();
1981        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
1982
1983        // sort (ascending)
1984        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1985        assert_eq!(sorted, [1, 2, 3, 4]);
1986
1987        // argsort (ascending indices)
1988        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1989        assert_eq!(idx, [1, 3, 0, 2]);
1990    }
1991
1992    #[cfg(feature = "metal")]
1993    #[test]
1994    fn test_sort_and_argsort_matrix_axis1_metal() {
1995        use crate::utils::ops::SortOp;
1996        use candle_core::Tensor;
1997
1998        let device = candle_core::Device::new_metal(0).unwrap();
1999        // 2 × 3 matrix:
2000        // [[3, 1, 2],
2001        //  [0, 4, 5]]
2002        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
2003
2004        // Sort along axis=1 (second dimension)
2005        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
2006        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
2007
2008        // ArgSort indices along axis=1
2009        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
2010        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
2011    }
2012
2013    // ─────────────────────────────── 4096-element vector ────────────────────────────────
2014    #[cfg(feature = "metal")]
2015    #[test]
2016    fn test_sort_and_argsort_vector_4096_metal() {
2017        use crate::utils::ops::SortOp;
2018        use candle_core::Tensor;
2019
2020        const N: usize = 4096;
2021
2022        let device = candle_core::Device::new_metal(0).expect("Metal device");
2023
2024        // Create a descending vector [4095, 4094, …, 0]
2025        let vals: Vec<i32> = (0..N as i32).rev().collect();
2026        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
2027
2028        // ---- sort (ascending) ---------------------------------------------------------
2029        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
2030        let expected: Vec<i32> = (0..N as i32).collect();
2031        assert_eq!(sorted, expected);
2032
2033        // ---- argsort (indices that would sort) ---------------------------------------
2034        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
2035        // Because the input is reversed, the correct indices are likewise reversed
2036        for (i, &v) in idx.iter().enumerate() {
2037            assert_eq!(v as usize, N - 1 - i);
2038        }
2039    }
2040}