mistralrs_quant/utils/ops.rs

1use candle_core::{
2    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3    Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::{
8    fmt::Display,
9    ops::{BitAnd, BitOr, BitXor, Not, Shl},
10};
11
12#[cfg(feature = "cuda")]
13use crate::utils::ffi;
14#[cfg(feature = "cuda")]
15use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
16#[cfg(feature = "cuda")]
17use std::ffi::c_void;
18
19#[cfg(feature = "metal")]
20use crate::metal_kernels::SortScratchCache; // scratch-buffer cache for the Metal sort kernels
21#[cfg(feature = "metal")]
22use std::sync::OnceLock;
23
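// Process-wide cache of scratch buffers for the Metal sort kernels. It is
// initialized lazily on first use; the argsort/sort paths below create it with
// `SortScratchCache::new(4)` and check buffers out per kernel launch.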
24#[cfg(feature = "metal")]
25static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
27struct Leftshift(usize);
28
29impl Leftshift {
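    // Element-wise `v << n` over a CPU slice: the shift amount is converted to
    // the element dtype via `from_f64` and the map is parallelized with rayon.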
30    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31        let offset = T::from_f64(self.0 as f64);
32        vs.into_par_iter().map(|v| *v << offset).collect()
33    }
34}
35
36impl CustomOp1 for Leftshift {
37    fn name(&self) -> &'static str {
38        "left"
39    }
40
41    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42        match s1 {
43            CpuStorage::U8(vs1) => {
44                let vs1 = match l1.contiguous_offsets() {
45                    Some((a, b)) => &vs1[a..b],
46                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
47                };
48                let result = self.leftshift(vs1);
49                let result = CpuStorage::U8(result);
50                Ok((result, l1.shape().clone()))
51            }
52            CpuStorage::I16(vs1) => {
53                let vs1 = match l1.contiguous_offsets() {
54                    Some((a, b)) => &vs1[a..b],
55                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
56                };
57                let result = self.leftshift(vs1);
58                let result = CpuStorage::I16(result);
59                Ok((result, l1.shape().clone()))
60            }
61            CpuStorage::U32(vs1) => {
62                let vs1 = match l1.contiguous_offsets() {
63                    Some((a, b)) => &vs1[a..b],
64                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
65                };
66                let result = self.leftshift(vs1);
67                let result = CpuStorage::U32(result);
68                Ok((result, l1.shape().clone()))
69            }
70            CpuStorage::I64(vs1) => {
71                let vs1 = match l1.contiguous_offsets() {
72                    Some((a, b)) => &vs1[a..b],
73                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
74                };
75                let result = self.leftshift(vs1);
76                let result = CpuStorage::I64(result);
77                Ok((result, l1.shape().clone()))
78            }
79            CpuStorage::I32(vs1) => {
80                let vs1 = match l1.contiguous_offsets() {
81                    Some((a, b)) => &vs1[a..b],
82                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
83                };
84                let result = self.leftshift(vs1);
85                let result = CpuStorage::I32(result);
86                Ok((result, l1.shape().clone()))
87            }
88            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")),
89            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift")),
90            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift")),
91            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift")),
92            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift")),
93        }
94    }
95    #[cfg(feature = "cuda")]
96    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
97        if !l1.is_contiguous() {
98            candle_core::bail!("Input tensor s1 must be contiguous");
99        }
100        let dev = s1.device().clone();
101        let (d_in1_ptr, elem_count) = match s1.dtype() {
102            DType::U8 => {
103                let d_in1_ptr = *s1
104                    .as_cuda_slice::<u8>()?
105                    .slice(l1.start_offset()..)
106                    .device_ptr() as *const c_void;
107                let elem_count = l1.shape().elem_count();
108                (d_in1_ptr, elem_count)
109            }
110            DType::I16 => {
111                return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift"));
112            }
113            DType::U32 => {
114                return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
115            }
116            DType::I64 => {
117                return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
118            }
119            DType::I32 => {
120                let d_in1_ptr = *s1
121                    .as_cuda_slice::<i32>()?
122                    .slice(l1.start_offset()..)
123                    .device_ptr() as *const c_void;
124                let elem_count = l1.shape().elem_count();
125                (d_in1_ptr, elem_count)
126            }
127            DType::BF16 => {
128                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
129            }
130            DType::F16 => {
131                return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
132            }
133            DType::F32 => {
134                return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
135            }
136            DType::F64 => {
137                return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
138            }
139            DType::F8E4M3 => {
140                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
141            }
142        };
143        let dst = match s1.dtype() {
144            DType::U8 => {
145                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
146                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
147                unsafe {
148                    ffi::leftshift_u8(
149                        d_in1_ptr,
150                        d_out_ptr,
151                        u32::try_from(elem_count)?,
152                        self.0 as i32,
153                    )
154                };
155                CudaStorage::wrap_cuda_slice(d_out, dev)
156            }
157            DType::I32 => {
158                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
159                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
160                unsafe {
161                    ffi::leftshift_i32(
162                        d_in1_ptr,
163                        d_out_ptr,
164                        u32::try_from(elem_count)?,
165                        self.0 as i32,
166                    )
167                };
168                CudaStorage::wrap_cuda_slice(d_out, dev)
169            }
170            _ => unreachable!(),
171        };
172        Ok((dst, l1.shape().clone()))
173    }
174    #[cfg(feature = "metal")]
175    fn metal_fwd(
176        &self,
177        s1: &candle_core::MetalStorage,
178        l1: &Layout,
179    ) -> Result<(candle_core::MetalStorage, Shape)> {
180        if !l1.is_contiguous() {
181            candle_core::bail!("Input tensor s1 must be contiguous");
182        }
183
184        let command_buffer = s1.device().command_buffer()?;
185        command_buffer.set_label("bitwise-leftshift");
186
187        let device = s1.device();
188
189        let out_shape = l1.shape().clone();
190
191        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
192
193        crate::metal_kernels::call_bitwise_leftshift(
194            device.device(),
195            &command_buffer,
196            &crate::metal_kernels::Kernels::new(),
197            s1.dtype(),
198            s1.buffer(),
199            l1.start_offset(),
200            self.0 as u32,
201            out_shape.elem_count(),
202            &output,
203        )
204        .map_err(candle_core::Error::wrap)?;
205
206        let newstorage = candle_core::MetalStorage::new(
207            output,
208            device.clone(),
209            out_shape.elem_count(),
210            s1.dtype(),
211        );
212        Ok((newstorage, out_shape))
213    }
214}
215
216#[allow(dead_code)]
217pub trait LeftshiftOp {
218    fn leftshift(&self, n: usize) -> Result<Tensor>;
219}
220
221impl LeftshiftOp for Tensor {
222    fn leftshift(&self, n: usize) -> Result<Tensor> {
223        self.apply_op1_no_bwd(&Leftshift(n))
224    }
225}
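
// A minimal usage sketch for `LeftshiftOp`, assuming a CPU device and an
// unsigned integer dtype (the op only supports integer storage).
#[cfg(test)]
mod leftshift_example {
    use super::LeftshiftOp;
    use candle_core::{Device, Tensor};

    #[test]
    fn shifts_u32_values() -> candle_core::Result<()> {
        let t = Tensor::new(&[1u32, 2, 4], &Device::Cpu)?;
        // 1 << 3 = 8, 2 << 3 = 16, 4 << 3 = 32
        let shifted = t.leftshift(3)?;
        assert_eq!(shifted.to_vec1::<u32>()?, vec![8, 16, 32]);
        Ok(())
    }
}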
226
227pub enum BitWiseBinaryOpEnum {
228    And,
229    Or,
230    Xor,
231}
232
233impl Display for BitWiseBinaryOpEnum {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        match self {
236            BitWiseBinaryOpEnum::And => write!(f, "And"),
237            BitWiseBinaryOpEnum::Or => write!(f, "Or"),
238            BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
239        }
240    }
241}
242
243pub enum BitWiseUnaryOpEnum {
244    Not,
245}
246
247impl Display for BitWiseUnaryOpEnum {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        match self {
250            BitWiseUnaryOpEnum::Not => write!(f, "Not"),
251        }
252    }
253}
254
255struct BitWise {
256    pub op: BitWiseBinaryOpEnum,
257}
258
259impl BitWise {
260    pub fn new(op: BitWiseBinaryOpEnum) -> Self {
261        Self { op }
262    }
263
264    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
265        &self,
266        vs1: &[T],
267        vs2: &[T],
268    ) -> Vec<T> {
269        vs1.into_par_iter()
270            .zip_eq(vs2)
271            .map(|(v1, v2)| match self.op {
272                BitWiseBinaryOpEnum::And => *v1 & *v2,
273                BitWiseBinaryOpEnum::Or => *v1 | *v2,
274                BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
275            })
276            .collect()
277    }
278}
279
280impl CustomOp2 for BitWise {
281    fn name(&self) -> &'static str {
282        "bitwise"
283    }
284
285    fn cpu_fwd(
286        &self,
287        s1: &CpuStorage,
288        l1: &Layout,
289        s2: &CpuStorage,
290        l2: &Layout,
291    ) -> Result<(CpuStorage, Shape)> {
292        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
293            return Err(Error::ShapeMismatchBinaryOp {
294                lhs: l1.shape().clone(),
295                rhs: l2.shape().clone(),
296                op: "bitwise-op",
297            });
298        }
299        if s1.dtype() != s2.dtype() {
300            return Err(Error::DTypeMismatchBinaryOp {
301                lhs: s1.dtype(),
302                rhs: s2.dtype(),
303                op: "bitwise-op",
304            });
305        }
306        if !l1.is_contiguous() {
307            candle_core::bail!("Input tensor s1 must be contiguous");
308        }
309        if !l2.is_contiguous() {
310            candle_core::bail!("Input tensor s2 must be contiguous");
311        }
312
313        match s1 {
314            CpuStorage::U8(vs1) => {
315                let vs2 = s2.as_slice::<u8>().unwrap();
316                let vs1 = match l1.contiguous_offsets() {
317                    Some((a, b)) => &vs1[a..b],
318                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
319                };
320                let vs2 = match l2.contiguous_offsets() {
321                    Some((a, b)) => &vs2[a..b],
322                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
323                };
324                let result = self.bitwise(vs1, vs2);
325                let result = CpuStorage::U8(result);
326                Ok((result, l1.shape().clone()))
327            }
328            CpuStorage::U32(vs1) => {
329                let vs2 = s2.as_slice::<u32>().unwrap();
330                let vs1 = match l1.contiguous_offsets() {
331                    Some((a, b)) => &vs1[a..b],
332                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
333                };
334                let vs2 = match l2.contiguous_offsets() {
335                    Some((a, b)) => &vs2[a..b],
336                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
337                };
338                let result = self.bitwise(vs1, vs2);
339                let result = CpuStorage::U32(result);
340                Ok((result, l1.shape().clone()))
341            }
342            CpuStorage::I64(vs1) => {
343                let vs2 = s2.as_slice::<i64>().unwrap();
344                let vs1 = match l1.contiguous_offsets() {
345                    Some((a, b)) => &vs1[a..b],
346                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
347                };
348                let vs2 = match l2.contiguous_offsets() {
349                    Some((a, b)) => &vs2[a..b],
350                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
351                };
352                let result = self.bitwise(vs1, vs2);
353                let result = CpuStorage::I64(result);
354                Ok((result, l1.shape().clone()))
355            }
356            CpuStorage::I16(vs1) => {
357                let vs2 = s2.as_slice::<i16>().unwrap();
358                let vs1 = match l1.contiguous_offsets() {
359                    Some((a, b)) => &vs1[a..b],
360                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
361                };
362                let vs2 = match l2.contiguous_offsets() {
363                    Some((a, b)) => &vs2[a..b],
364                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
365                };
366                let result = self.bitwise(vs1, vs2);
367                let result = CpuStorage::I16(result);
368                Ok((result, l1.shape().clone()))
369            }
370            CpuStorage::I32(vs1) => {
371                let vs2 = s2.as_slice::<i32>().unwrap();
372                let vs1 = match l1.contiguous_offsets() {
373                    Some((a, b)) => &vs1[a..b],
374                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
375                };
376                let vs2 = match l2.contiguous_offsets() {
377                    Some((a, b)) => &vs2[a..b],
378                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
379                };
380                let result = self.bitwise(vs1, vs2);
381                let result = CpuStorage::I32(result);
382                Ok((result, l1.shape().clone()))
383            }
384            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
385            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
386            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
387            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
388            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
389        }
390    }
391
392    #[cfg(feature = "cuda")]
393    fn cuda_fwd(
394        &self,
395        s1: &CudaStorage,
396        l1: &Layout,
397        s2: &CudaStorage,
398        l2: &Layout,
399    ) -> Result<(CudaStorage, Shape)> {
400        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
401            return Err(Error::ShapeMismatchBinaryOp {
402                lhs: l1.shape().clone(),
403                rhs: l2.shape().clone(),
404                op: "bitwise-op",
405            });
406        }
407        if s1.dtype() != s2.dtype() {
408            return Err(Error::DTypeMismatchBinaryOp {
409                lhs: s1.dtype(),
410                rhs: s2.dtype(),
411                op: "bitwise-op",
412            });
413        }
414        if !l1.is_contiguous() {
415            candle_core::bail!("Input tensor s1 must be contiguous");
416        }
417        if !l2.is_contiguous() {
418            candle_core::bail!("Input tensor s2 must be contiguous");
419        }
420
421        let dev = s1.device().clone();
422        let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
423            DType::U8 => {
424                let d_in1_ptr = *s1
425                    .as_cuda_slice::<u8>()?
426                    .slice(l1.start_offset()..)
427                    .device_ptr() as *const c_void;
428                let d_in2_ptr = *s2
429                    .as_cuda_slice::<u8>()?
430                    .slice(l2.start_offset()..)
431                    .device_ptr() as *const c_void;
432                let elem_count = l1.shape().elem_count();
433                (d_in1_ptr, d_in2_ptr, elem_count)
434            }
435            DType::U32 => {
436                let d_in1_ptr = *s1
437                    .as_cuda_slice::<u32>()?
438                    .slice(l1.start_offset()..)
439                    .device_ptr() as *const c_void;
440                let d_in2_ptr = *s2
441                    .as_cuda_slice::<u32>()?
442                    .slice(l2.start_offset()..)
443                    .device_ptr() as *const c_void;
444                let elem_count = l1.shape().elem_count();
445                (d_in1_ptr, d_in2_ptr, elem_count)
446            }
447            DType::I64 => {
448                let d_in1_ptr = *s1
449                    .as_cuda_slice::<i64>()?
450                    .slice(l1.start_offset()..)
451                    .device_ptr() as *const c_void;
452                let d_in2_ptr = *s2
453                    .as_cuda_slice::<i64>()?
454                    .slice(l2.start_offset()..)
455                    .device_ptr() as *const c_void;
456                let elem_count = l1.shape().elem_count();
457                (d_in1_ptr, d_in2_ptr, elem_count)
458            }
459            DType::I32 => {
460                let d_in1_ptr = *s1
461                    .as_cuda_slice::<i32>()?
462                    .slice(l1.start_offset()..)
463                    .device_ptr() as *const c_void;
464                let d_in2_ptr = *s2
465                    .as_cuda_slice::<i32>()?
466                    .slice(l2.start_offset()..)
467                    .device_ptr() as *const c_void;
468                let elem_count = l1.shape().elem_count();
469                (d_in1_ptr, d_in2_ptr, elem_count)
470            }
471            DType::I16 => {
472                let d_in1_ptr = *s1
473                    .as_cuda_slice::<i16>()?
474                    .slice(l1.start_offset()..)
475                    .device_ptr() as *const c_void;
476                let d_in2_ptr = *s2
477                    .as_cuda_slice::<i16>()?
478                    .slice(l2.start_offset()..)
479                    .device_ptr() as *const c_void;
480                let elem_count = l1.shape().elem_count();
481                (d_in1_ptr, d_in2_ptr, elem_count)
482            }
483            DType::BF16 => {
484                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise"));
485            }
486            DType::F16 => {
487                return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise"));
488            }
489            DType::F32 => {
490                return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise"));
491            }
492            DType::F64 => {
493                return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
494            }
495            DType::F8E4M3 => {
496                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
497            }
498        };
499        let dst = match s1.dtype() {
500            DType::U8 => {
501                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
502                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
503                unsafe {
504                    match self.op {
505                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
506                            d_in1_ptr,
507                            d_in2_ptr,
508                            d_out_ptr,
509                            u32::try_from(elem_count)?,
510                        ),
511                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
512                            d_in1_ptr,
513                            d_in2_ptr,
514                            d_out_ptr,
515                            u32::try_from(elem_count)?,
516                        ),
517                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
518                            d_in1_ptr,
519                            d_in2_ptr,
520                            d_out_ptr,
521                            u32::try_from(elem_count)?,
522                        ),
523                    }
524                };
525                CudaStorage::wrap_cuda_slice(d_out, dev)
526            }
527            DType::U32 => {
528                let d_out = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
529                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
530                unsafe {
531                    match self.op {
532                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
533                            d_in1_ptr,
534                            d_in2_ptr,
535                            d_out_ptr,
536                            u32::try_from(elem_count)?,
537                        ),
538                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
539                            d_in1_ptr,
540                            d_in2_ptr,
541                            d_out_ptr,
542                            u32::try_from(elem_count)?,
543                        ),
544                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
545                            d_in1_ptr,
546                            d_in2_ptr,
547                            d_out_ptr,
548                            u32::try_from(elem_count)?,
549                        ),
550                    }
551                };
552                CudaStorage::wrap_cuda_slice(d_out, dev)
553            }
554            DType::I64 => {
555                let d_out = unsafe { dev.alloc::<i64>(elem_count) }.w()?;
556                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
557                unsafe {
558                    match self.op {
559                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
560                            d_in1_ptr,
561                            d_in2_ptr,
562                            d_out_ptr,
563                            u32::try_from(elem_count)?,
564                        ),
565                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
566                            d_in1_ptr,
567                            d_in2_ptr,
568                            d_out_ptr,
569                            u32::try_from(elem_count)?,
570                        ),
571                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
572                            d_in1_ptr,
573                            d_in2_ptr,
574                            d_out_ptr,
575                            u32::try_from(elem_count)?,
576                        ),
577                    }
578                };
579                CudaStorage::wrap_cuda_slice(d_out, dev)
580            }
581            DType::I32 => {
582                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
583                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
584                unsafe {
585                    match self.op {
586                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
587                            d_in1_ptr,
588                            d_in2_ptr,
589                            d_out_ptr,
590                            u32::try_from(elem_count)?,
591                        ),
592                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
593                            d_in1_ptr,
594                            d_in2_ptr,
595                            d_out_ptr,
596                            u32::try_from(elem_count)?,
597                        ),
598                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
599                            d_in1_ptr,
600                            d_in2_ptr,
601                            d_out_ptr,
602                            u32::try_from(elem_count)?,
603                        ),
604                    }
605                };
606                CudaStorage::wrap_cuda_slice(d_out, dev)
607            }
608            _ => unreachable!(),
609        };
610        Ok((dst, l1.shape().clone()))
611    }
612
613    #[cfg(feature = "metal")]
614    fn metal_fwd(
615        &self,
616        s1: &candle_core::MetalStorage,
617        l1: &Layout,
618        s2: &candle_core::MetalStorage,
619        l2: &Layout,
620    ) -> Result<(candle_core::MetalStorage, Shape)> {
621        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
622            return Err(Error::ShapeMismatchBinaryOp {
623                lhs: l1.shape().clone(),
624                rhs: l2.shape().clone(),
625                op: "bitwise-op",
626            });
627        }
628        if s1.dtype() != s2.dtype() {
629            return Err(Error::DTypeMismatchBinaryOp {
630                lhs: s1.dtype(),
631                rhs: s2.dtype(),
632                op: "bitwise-op",
633            });
634        }
635        if !l1.is_contiguous() {
636            candle_core::bail!("Input tensor s1 must be contiguous");
637        }
638        if !l2.is_contiguous() {
639            candle_core::bail!("Input tensor s2 must be contiguous");
640        }
641
642        let command_buffer = s1.device().command_buffer()?;
643        command_buffer.set_label("bitwise-op");
644
645        let device = s1.device();
646
647        let out_shape = l1.shape().clone();
648
649        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
650
651        match self.op {
652            BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
653                device.device(),
654                &command_buffer,
655                &crate::metal_kernels::Kernels::new(),
656                s1.dtype(),
657                s1.buffer(),
658                s2.buffer(),
659                l1.start_offset() * s1.dtype().size_in_bytes(),
660                l2.start_offset() * s2.dtype().size_in_bytes(),
661                out_shape.elem_count(),
662                &output,
663            )
664            .map_err(candle_core::Error::wrap)?,
665            BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
666                device.device(),
667                &command_buffer,
668                &crate::metal_kernels::Kernels::new(),
669                s1.dtype(),
670                s1.buffer(),
671                s2.buffer(),
672                l1.start_offset() * s1.dtype().size_in_bytes(),
673                l2.start_offset() * s2.dtype().size_in_bytes(),
674                out_shape.elem_count(),
675                &output,
676            )
677            .map_err(candle_core::Error::wrap)?,
678            BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
679                device.device(),
680                &command_buffer,
681                &crate::metal_kernels::Kernels::new(),
682                s1.dtype(),
683                s1.buffer(),
684                s2.buffer(),
685                l1.start_offset() * s1.dtype().size_in_bytes(),
686                l2.start_offset() * s2.dtype().size_in_bytes(),
687                out_shape.elem_count(),
688                &output,
689            )
690            .map_err(candle_core::Error::wrap)?,
691        }
692
693        let newstorage = candle_core::MetalStorage::new(
694            output,
695            device.clone(),
696            out_shape.elem_count(),
697            s1.dtype(),
698        );
699        Ok((newstorage, out_shape))
700    }
701}
702
703struct BitWiseUnary {
704    pub op: BitWiseUnaryOpEnum,
705}
706
707impl BitWiseUnary {
708    pub fn new(op: BitWiseUnaryOpEnum) -> Self {
709        Self { op }
710    }
711
712    fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
713        vs1.into_par_iter()
714            .map(|v1| match self.op {
715                BitWiseUnaryOpEnum::Not => !*v1,
716            })
717            .collect()
718    }
719}
720
721impl CustomOp1 for BitWiseUnary {
722    fn name(&self) -> &'static str {
723        "bitwise-unary"
724    }
725
726    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
727        if !l1.is_contiguous() {
728            candle_core::bail!("Input tensor s1 must be contiguous");
729        }
730
731        match s1 {
732            CpuStorage::U8(vs1) => {
733                let vs1 = match l1.contiguous_offsets() {
734                    Some((a, b)) => &vs1[a..b],
735                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
736                };
737                let result = self.bitwise(vs1);
738                let result = CpuStorage::U8(result);
739                Ok((result, l1.shape().clone()))
740            }
741            CpuStorage::U32(vs1) => {
742                let vs1 = match l1.contiguous_offsets() {
743                    Some((a, b)) => &vs1[a..b],
744                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
745                };
746                let result = self.bitwise(vs1);
747                let result = CpuStorage::U32(result);
748                Ok((result, l1.shape().clone()))
749            }
750            CpuStorage::I64(vs1) => {
751                let vs1 = match l1.contiguous_offsets() {
752                    Some((a, b)) => &vs1[a..b],
753                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
754                };
755                let result = self.bitwise(vs1);
756                let result = CpuStorage::I64(result);
757                Ok((result, l1.shape().clone()))
758            }
759            CpuStorage::I16(vs1) => {
760                let vs1 = match l1.contiguous_offsets() {
761                    Some((a, b)) => &vs1[a..b],
762                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
763                };
764                let result = self.bitwise(vs1);
765                let result = CpuStorage::I16(result);
766                Ok((result, l1.shape().clone()))
767            }
768            CpuStorage::I32(vs1) => {
769                let vs1 = match l1.contiguous_offsets() {
770                    Some((a, b)) => &vs1[a..b],
771                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
772                };
773                let result = self.bitwise(vs1);
774                let result = CpuStorage::I32(result);
775                Ok((result, l1.shape().clone()))
776            }
777            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
778            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
779            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
780            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
781            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
782        }
783    }
784
785    #[cfg(feature = "cuda")]
786    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
787        todo!()
788    }
789
790    #[cfg(feature = "metal")]
791    fn metal_fwd(
792        &self,
793        s1: &candle_core::MetalStorage,
794        l1: &Layout,
795    ) -> Result<(candle_core::MetalStorage, Shape)> {
796        if !l1.is_contiguous() {
797            candle_core::bail!("Input tensor s1 must be contiguous");
798        }
799
800        let command_buffer = s1.device().command_buffer()?;
801        command_buffer.set_label("bitwise-unary-op");
802
803        let device = s1.device();
804
805        let out_shape = l1.shape().clone();
806
807        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
808
809        match self.op {
810            BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
811                device.device(),
812                &command_buffer,
813                &crate::metal_kernels::Kernels::new(),
814                s1.dtype(),
815                s1.buffer(),
816                l1.start_offset() * s1.dtype().size_in_bytes(),
817                out_shape.elem_count(),
818                &output,
819            )
820            .map_err(candle_core::Error::wrap)?,
821        }
822
823        let newstorage = candle_core::MetalStorage::new(
824            output,
825            device.clone(),
826            out_shape.elem_count(),
827            s1.dtype(),
828        );
829        Ok((newstorage, out_shape))
830    }
831}
832
833#[allow(dead_code)]
834pub trait BitWiseOp {
835    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
836    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
837    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
838    fn bitwise_not(&self) -> Result<Tensor>;
839}
840
841impl BitWiseOp for Tensor {
842    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
843        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
844    }
845
846    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
847        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
848    }
849
850    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
851        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
852    }
853
854    fn bitwise_not(&self) -> Result<Tensor> {
855        self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
856    }
857}
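
// A minimal usage sketch for `BitWiseOp` on CPU integer tensors; both operands
// must share the same shape, strides, and dtype.
#[cfg(test)]
mod bitwise_example {
    use super::BitWiseOp;
    use candle_core::{Device, Tensor};

    #[test]
    fn and_or_xor_not() -> candle_core::Result<()> {
        let a = Tensor::new(&[0b1100u32, 0b1010], &Device::Cpu)?;
        let b = Tensor::new(&[0b1010u32, 0b0110], &Device::Cpu)?;
        assert_eq!(a.bitwise_and(&b)?.to_vec1::<u32>()?, vec![0b1000, 0b0010]);
        assert_eq!(a.bitwise_or(&b)?.to_vec1::<u32>()?, vec![0b1110, 0b1110]);
        assert_eq!(a.bitwise_xor(&b)?.to_vec1::<u32>()?, vec![0b0110, 0b1100]);
        assert_eq!(a.bitwise_not()?.to_vec1::<u32>()?, vec![!0b1100u32, !0b1010u32]);
        Ok(())
    }
}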
858
859// ────────────────────────────── ArgSort / Sort ────────────────────────────────
860
861#[allow(unused)]
862/// Configuration for an **argsort** (returns indices) operation.
863struct ArgSort {
864    axis: usize,
865}
866
867#[allow(unused)]
868/// Configuration for a **sort** (returns re‑ordered values) operation.
869struct Sort {
870    axis: usize,
871}
872
873impl CustomOp1 for ArgSort {
874    fn name(&self) -> &'static str {
875        "argsort"
876    }
877
878    // -------- CPU ------------------------------------------------------------
879    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
880        candle_core::bail!("ArgSort is not implemented for the CPU backend");
881    }
882
883    // -------- CUDA -----------------------------------------------------------
884    #[cfg(feature = "cuda")]
885    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
886        candle_core::bail!("ArgSort is not implemented for the CUDA backend");
887    }
888
889    // -------- Metal ----------------------------------------------------------
890    #[cfg(feature = "metal")]
891    fn metal_fwd(
892        &self,
893        s1: &candle_core::MetalStorage,
894        l1: &Layout,
895    ) -> Result<(candle_core::MetalStorage, Shape)> {
896        // Require contiguous input (same as other metal ops in this file)
897        if !l1.is_contiguous() {
898            candle_core::bail!("Input tensor s1 must be contiguous");
899        }
900
901        // Create a command‑buffer and label it for easy debugging in Xcode’s GPU frame‑capture
902        let command_buffer = s1.device().command_buffer()?;
903        command_buffer.set_label("argsort");
904
905        let device = s1.device();
906        let out_shape = l1.shape().clone();
907        let elem_count = out_shape.elem_count();
908
909        // Output buffer holds the sorted indices → always `U32`
910        let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
911
912        // ------------------------------------------------------------------
913        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
914        // ------------------------------------------------------------------
915        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
916
917        let dims = l1.dims();
918        let size_sorted_axis = dims[self.axis];
919        let n_rows = l1.shape().elem_count() / size_sorted_axis;
920
921        // Replicate the kernel’s internal block sizing to derive `n_blocks`
922        let tn = 4usize;
923        let mut bn = match size_sorted_axis.div_ceil(tn) {
924            v if v > 256 => 512,
925            v if v > 128 => 256,
926            v if v > 64 => 128,
927            v if v > 32 => 64,
928            _ => 32,
929        };
930        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
931            bn = 256;
932        }
933        let n_per_block = bn * tn;
934        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
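        // For example, sorting rows of 1000 elements: 1000.div_ceil(4) = 250, so
        // bn = 256, n_per_block = 1024, and the whole row fits in a single block
        // (n_blocks = 1); the multi-block path only kicks in for longer rows.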
935
936        // Borrow the buffers for this launch
937        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
938
939        // ------------------------------------------------------------------
940        // Build the unified SortArgs payload
941        // ------------------------------------------------------------------
942        let sort_args = crate::metal_kernels::SortArgs {
943            axis: self.axis,
944            shape: l1.dims(),
945            strides: l1.stride(),
946            out_shape: l1.dims(), // same as input for argsort
947            out_strides: l1.stride(),
948            in_contiguous: l1.is_contiguous(),
949            in_ty: s1.dtype(),
950            out_ty: candle_core::DType::U32,
951            src: s1.buffer(),
952            src_offset: l1.start_offset(), // element offset
953            dst: &output,
954            bn,
955            tn,
956            n_blocks,
957        };
958
959        // Launch the Metal kernel via the new API
960        crate::metal_kernels::call_argsort(
961            device.device(), // &metal::Device
962            &command_buffer, // impl EncoderProvider
963            &crate::metal_kernels::Kernels::new(),
964            &sort_args,
965            &scratch,
966        )
967        .map_err(candle_core::Error::wrap)?;
968
969        // Wrap and return as a new MetalStorage
970        let newstorage = candle_core::MetalStorage::new(
971            output,
972            device.clone(),
973            elem_count,
974            candle_core::DType::U32,
975        );
976        Ok((newstorage, out_shape))
977    }
978}
979
980impl CustomOp1 for Sort {
981    fn name(&self) -> &'static str {
982        "sort"
983    }
984
985    // -------- CPU ------------------------------------------------------------
986    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
987        candle_core::bail!("Sort is not implemented for the CPU backend");
988    }
989
990    // -------- CUDA -----------------------------------------------------------
991    #[cfg(feature = "cuda")]
992    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
993        candle_core::bail!("Sort is not implemented for the CUDA backend");
994    }
995
996    // -------- Metal ----------------------------------------------------------
997    #[cfg(feature = "metal")]
998    fn metal_fwd(
999        &self,
1000        s1: &candle_core::MetalStorage,
1001        l1: &Layout,
1002    ) -> Result<(candle_core::MetalStorage, Shape)> {
1003        // Require contiguous input (same as other metal ops in this file)
1004        if !l1.is_contiguous() {
1005            candle_core::bail!("Input tensor s1 must be contiguous");
1006        }
1007
1008        // Create a command‑buffer and label it for easy debugging in Xcode’s GPU frame‑capture
1009        let command_buffer = s1.device().command_buffer()?;
1010        command_buffer.set_label("sort");
1011
1012        let device = s1.device();
1013        let out_shape = l1.shape().clone();
1014        let elem_count = out_shape.elem_count();
1015
1016        // Output buffer keeps the same dtype as the input (these are the reordered values)
1017        let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
1018
1019        // ------------------------------------------------------------------
1020        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
1021        // ------------------------------------------------------------------
1022        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
1023
1024        let dims = l1.dims();
1025        let size_sorted_axis = dims[self.axis];
1026        let n_rows = l1.shape().elem_count() / size_sorted_axis;
1027
1028        // Replicate the kernel’s internal block sizing to derive `n_blocks`
1029        let tn = 4usize;
1030        let mut bn = match size_sorted_axis.div_ceil(tn) {
1031            v if v > 256 => 512,
1032            v if v > 128 => 256,
1033            v if v > 64 => 128,
1034            v if v > 32 => 64,
1035            _ => 32,
1036        };
1037        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
1038            bn = 256;
1039        }
1040        let n_per_block = bn * tn;
1041        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
1042
1043        // Borrow the buffers for this launch
1044        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1045
1046        // ------------------------------------------------------------------
1047        // Build the unified SortArgs payload
1048        // ------------------------------------------------------------------
1049        let sort_args = crate::metal_kernels::SortArgs {
1050            axis: self.axis,
1051            shape: l1.dims(),
1052            strides: l1.stride(),
1053            out_shape: l1.dims(), // same shape for value sort
1054            out_strides: l1.stride(),
1055            in_contiguous: l1.is_contiguous(),
1056            in_ty: s1.dtype(),
1057            out_ty: s1.dtype(),
1058            src: s1.buffer(),
1059            src_offset: l1.start_offset(), // element offset
1060            dst: &output,
1061            bn,
1062            tn,
1063            n_blocks,
1064        };
1065
1066        // Launch the Metal kernel via the new API
1067        crate::metal_kernels::call_sort(
1068            device.device(), // &metal::Device
1069            &command_buffer, // impl EncoderProvider
1070            &crate::metal_kernels::Kernels::new(),
1071            &sort_args,
1072            &scratch,
1073        )
1074        .map_err(candle_core::Error::wrap)?;
1075
1076        // Wrap and return as a new MetalStorage
1077        let newstorage =
1078            candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1079        Ok((newstorage, out_shape))
1080    }
1081}
1082
1083/// Extension trait adding `argsort` / `sort` convenience calls on `Tensor`.
1084pub trait SortOp {
1085    /// Returns the indices that would (ascending) sort the tensor along `axis`.
1086    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1087    /// Returns the tensor's values (ascending) sorted along `axis`.
1088    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1089}
1090
1091impl SortOp for Tensor {
1092    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1093        if self.device().is_cpu() || self.device().is_cuda() {
1094            return self.arg_sort_last_dim(true);
1095        }
1096        self.apply_op1_no_bwd(&ArgSort {
1097            axis: axis.to_index(self.shape(), "argsort")?,
1098        })
1099    }
1100
1101    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1102        if self.device().is_cpu() || self.device().is_cuda() {
1103            return Ok(self.sort_last_dim(true)?.0);
1104        }
1105        self.apply_op1_no_bwd(&Sort {
1106            axis: axis.to_index(self.shape(), "sort")?,
1107        })
1108    }
1109}
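
// A minimal usage sketch for `SortOp`. On CPU and CUDA these fall back to
// candle's built-in last-dim sort (so `axis` is effectively ignored there); on
// Metal they dispatch the custom kernels above.
#[cfg(test)]
mod sort_example {
    use super::SortOp;
    use candle_core::{Device, Tensor};

    #[test]
    fn argsort_and_sort_ascending() -> candle_core::Result<()> {
        let t = Tensor::new(&[3.0f32, 1.0, 2.0], &Device::Cpu)?;
        // Indices that would sort the values ascending.
        assert_eq!(t.fast_argsort_asc(0)?.to_vec1::<u32>()?, vec![1, 2, 0]);
        // The values themselves, ascending.
        assert_eq!(t.fast_sort_asc(0)?.to_vec1::<f32>()?, vec![1.0, 2.0, 3.0]);
        Ok(())
    }
}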
1110
1111struct NonZero;
1112
1113impl NonZero {
1114    // Sequential version
1115    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1116        let n = layout.dims().len();
1117        let mut result = Vec::new();
1118        let mut indices = vec![0u32; n];
1119        for (i, v) in vs.iter().enumerate() {
1120            if !v.is_zero() {
1121                let mut idx = i;
1122                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1123                    let d = idx % dim;
1124                    indices[dim_index] = u32::try_from(d).unwrap();
1125                    idx /= dim;
1126                }
1127                result.extend_from_slice(&indices);
1128            }
1129        }
1130        result
1131    }
1132}
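
// Example of the decomposition above: for a tensor of shape [2, 3], the flat
// offset 5 splits (last dim first) into 5 % 3 = 2 and 5 / 3 = 1, giving the
// multi-dimensional index [1, 2]; `result` holds one such row per nonzero
// element, in row-major order.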
1133
1134#[cfg(feature = "cuda")]
1135fn count_nonzero_cuda(
1136    dtype: candle_core::DType,
1137    d_in: *const c_void,
1138    n: u32,
1139    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1140) -> u32 {
1141    unsafe {
1142        match dtype {
1143            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1144            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1145            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1146            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1147            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1148            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1149            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1150            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1151            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1152            candle_core::DType::F8E4M3 => todo!(),
1153        }
1154    }
1155}
1156
1157#[allow(clippy::too_many_arguments)]
1158#[cfg(feature = "cuda")]
1159fn nonzero_cuda(
1160    dtype: candle_core::DType,
1161    d_in: *const c_void,
1162    n: u32,
1163    num_nonzero: u32,
1164    dims: *const c_void,
1165    num_dims: u32,
1166    d_out: *mut c_void,
1167    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1168) {
1169    unsafe {
1170        match dtype {
1171            candle_core::DType::U8 => {
1172                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1173            }
1174            candle_core::DType::U32 => {
1175                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1176            }
1177            candle_core::DType::I64 => {
1178                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1179            }
1180            candle_core::DType::I32 => {
1181                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1182            }
1183            candle_core::DType::I16 => {
1184                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1185            }
1186            candle_core::DType::BF16 => {
1187                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1188            }
1189            candle_core::DType::F16 => {
1190                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1191            }
1192            candle_core::DType::F32 => {
1193                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1194            }
1195            candle_core::DType::F64 => {
1196                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1197            }
1198            candle_core::DType::F8E4M3 => todo!(),
1199        }
1200    }
1201}
1202
1203impl CustomOp1 for NonZero {
1204    fn name(&self) -> &'static str {
1205        "nonzero"
1206    }
1207
1208    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1209        if !layout.is_contiguous() {
1210            return Err(Error::RequiresContiguous { op: "nonzero" });
1211        }
1212        let result = match storage {
1213            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1214            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1215            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1216            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1217            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1218            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1219            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1220            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1221            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1222            candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
1223        };
1224        let index_len = layout.dims().len();
1225        let result_len = result.len() / index_len;
1226        let result = CpuStorage::U32(result);
1227        let shape = Shape::from_dims(&[result_len, index_len]);
1228        Ok((result, shape))
1229    }
1230    #[cfg(feature = "cuda")]
1231    fn cuda_fwd(
1232        &self,
1233        storage: &candle_core::CudaStorage,
1234        layout: &Layout,
1235    ) -> Result<(candle_core::CudaStorage, Shape)> {
1236        if !layout.is_contiguous() {
1237            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1238        }
1239        let dev = storage.device().clone();
1240        let d_in = match storage.dtype() {
1241            candle_core::DType::U8 => *storage.as_cuda_slice::<u8>()?.device_ptr(),
1242            candle_core::DType::U32 => *storage.as_cuda_slice::<u32>()?.device_ptr(),
1243            candle_core::DType::I32 => *storage.as_cuda_slice::<i32>()?.device_ptr(),
1244            candle_core::DType::I16 => *storage.as_cuda_slice::<i16>()?.device_ptr(),
1245            candle_core::DType::I64 => *storage.as_cuda_slice::<i64>()?.device_ptr(),
1246            candle_core::DType::BF16 => *storage.as_cuda_slice::<half::bf16>()?.device_ptr(),
1247            candle_core::DType::F16 => *storage.as_cuda_slice::<half::f16>()?.device_ptr(),
1248            candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
1249            candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
1250            candle_core::DType::F8E4M3 => todo!(),
1251        } as *const c_void;
1252        let n = layout.shape().elem_count();
1253
1254        let num_nonzero =
1255            count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream());
1256        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1257            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1258        if num_nonzero != 0 {
1259            let d_out_ptr = *d_out.device_ptr() as *mut c_void;
1260            let dims = layout
1261                .dims()
1262                .iter()
1263                .map(|&x| u32::try_from(x).unwrap())
1264                .collect::<Vec<u32>>();
1265            let d_dims = dev
1266                .htod_copy(dims)
1267                .map_err(|_| Error::Msg("Failed to copy dims to device".to_string()))?;
1268            let d_dims_ptr = *d_dims.device_ptr() as *const c_void;
1269            nonzero_cuda(
1270                storage.dtype(),
1271                d_in,
1272                u32::try_from(n)?,
1273                num_nonzero,
1274                d_dims_ptr,
1275                u32::try_from(layout.dims().len())?,
1276                d_out_ptr,
1277                *dev.cu_stream(),
1278            );
1279        }
1280        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1281        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1282        Ok((dst, shape))
1283    }
1284}
1285
1286pub trait NonZeroOp {
1287    fn nonzero(&self) -> Result<Tensor>;
1288}
1289
1290impl NonZeroOp for Tensor {
1291    #[cfg(feature = "metal")]
1292    fn nonzero(&self) -> Result<Tensor> {
1293        if !self.is_contiguous() {
1294            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1295        }
1296        let original_device = self.device();
1297        self.to_device(&candle_core::Device::Cpu)?
1298            .apply_op1_no_bwd(&NonZero)?
1299            .to_device(original_device)
1300    }
1301
1302    #[cfg(not(feature = "metal"))]
1303    fn nonzero(&self) -> Result<Tensor> {
1304        if !self.is_contiguous() {
1305            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1306        }
1307        self.apply_op1_no_bwd(&NonZero)
1308    }
1309}
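
// A minimal usage sketch for `NonZeroOp`: the result is a U32 tensor of shape
// [n_nonzero, n_dims] holding one multi-dimensional index per nonzero element.
#[cfg(test)]
mod nonzero_example {
    use super::NonZeroOp;
    use candle_core::{Device, Tensor};

    #[test]
    fn indices_of_nonzero_elements() -> candle_core::Result<()> {
        let t = Tensor::new(&[[0u32, 3, 0], [5, 0, 7]], &Device::Cpu)?;
        let idx = t.nonzero()?;
        assert_eq!(idx.dims2()?, (3, 2));
        assert_eq!(
            idx.to_vec2::<u32>()?,
            vec![vec![0, 1], vec![1, 0], vec![1, 2]]
        );
        Ok(())
    }
}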
1310
1311struct CumSum {
1312    inclusive: bool,
1313    reverse: bool,
1314    axis: usize,
1315}
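
// Scan semantics, illustrated on [1, 2, 3] along `axis`:
//   inclusive, forward -> [1, 3, 6]
//   exclusive, forward -> [0, 1, 3]
//   inclusive, reverse -> [6, 5, 3]
//   exclusive, reverse -> [5, 3, 0]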
1316
1317impl CustomOp1 for CumSum {
1318    fn name(&self) -> &'static str {
1319        "cumsum"
1320    }
1321
1322    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1323        use std::ops::Add;
1324        if !l1.is_contiguous() {
1325            candle_core::bail!("Input tensor s1 must be contiguous");
1326        }
1327        let dims = l1.dims();
1328        let axis = self.axis;
1329        let axis_len = dims[axis];
1330        let (start, end) = l1
1331            .contiguous_offsets()
1332            .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1333
1334        // helper to execute scan for a slice of T
1335        macro_rules! scan_block {
1336            ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1337                let vs: &[$ty] = $vt;
1338                let input = &vs[start..end];
1339                let count = input.len() / axis_len;
1340                let mut result = Vec::<$ty>::with_capacity(input.len());
1341                if !self.reverse {
1342                    if self.inclusive {
1343                        for block in 0..count {
1344                            let base = block * axis_len;
1345                            let mut sum = input[base];
1346                            result.push(sum);
1347                            for j in 1..axis_len {
1348                                sum = sum.$add(input[base + j]);
1349                                result.push(sum);
1350                            }
1351                        }
1352                    } else {
1353                        let init: $ty = $init;
1354                        for block in 0..count {
1355                            let base = block * axis_len;
1356                            let mut sum = init;
1357                            for j in 0..axis_len {
1358                                result.push(sum);
1359                                sum = sum.$add(input[base + j]);
1360                            }
1361                        }
1362                    }
1363                } else {
1364                    if self.inclusive {
1365                        for block in 0..count {
1366                            let base = block * axis_len;
1367                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1368                            let mut sum = input[base + axis_len - 1];
1369                            temp.push(sum);
1370                            for k in 1..axis_len {
1371                                let idx = axis_len - 1 - k;
1372                                sum = sum.$add(input[base + idx]);
1373                                temp.push(sum);
1374                            }
1375                            temp.reverse();
1376                            result.extend(temp);
1377                        }
1378                    } else {
1379                        let init: $ty = $init;
1380                        for block in 0..count {
1381                            let base = block * axis_len;
1382                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1383                            let mut sum = init;
1384                            for k in 0..axis_len {
1385                                let idx = axis_len - 1 - k;
1386                                temp.push(sum);
1387                                sum = sum.$add(input[base + idx]);
1388                            }
1389                            temp.reverse();
1390                            result.extend(temp);
1391                        }
1392                    }
1393                }
1394                result
1395            }};
1396        }
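        // Dispatch on the storage dtype. Unsigned blocks use `wrapping_add` so
        // an overflowing prefix sum wraps instead of panicking in debug builds;
        // signed integer and float blocks use plain `add`.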
1397        match s1 {
1398            CpuStorage::U8(vs) => {
1399                let result = scan_block!(vs, u8, wrapping_add, 0u8);
1400                Ok((CpuStorage::U8(result), l1.shape().clone()))
1401            }
1402            CpuStorage::I16(vs) => {
1403                let result = scan_block!(vs, i16, add, 0i16);
1404                Ok((CpuStorage::I16(result), l1.shape().clone()))
1405            }
1406            CpuStorage::U32(vs) => {
1407                let result = scan_block!(vs, u32, wrapping_add, 0u32);
1408                Ok((CpuStorage::U32(result), l1.shape().clone()))
1409            }
1410            CpuStorage::I32(vs) => {
1411                let result = scan_block!(vs, i32, add, 0i32);
1412                Ok((CpuStorage::I32(result), l1.shape().clone()))
1413            }
1414            CpuStorage::I64(vs) => {
1415                let result = scan_block!(vs, i64, add, 0i64);
1416                Ok((CpuStorage::I64(result), l1.shape().clone()))
1417            }
1418            CpuStorage::F32(vs) => {
1419                let result = scan_block!(vs, f32, add, 0.0f32);
1420                Ok((CpuStorage::F32(result), l1.shape().clone()))
1421            }
1422            CpuStorage::F64(vs) => {
1423                let result = scan_block!(vs, f64, add, 0.0f64);
1424                Ok((CpuStorage::F64(result), l1.shape().clone()))
1425            }
1426            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1427        }
1428    }
1429
1430    #[cfg(feature = "cuda")]
1431    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
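        // No CUDA scan kernel is wired up yet; reaching this path panics via `todo!()`.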
1432        todo!()
1433    }
1434
1435    #[cfg(feature = "metal")]
1436    fn metal_fwd(
1437        &self,
1438        s1: &candle_core::MetalStorage,
1439        l1: &Layout,
1440    ) -> Result<(candle_core::MetalStorage, Shape)> {
1441        use crate::metal_kernels::ScanType;
1442
1443        let command_buffer = s1.device().command_buffer()?;
1444        command_buffer.set_label("cumsum");
1445
1446        let device = s1.device();
1447
1448        let out_shape = l1.shape().clone();
1449
1450        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1451
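        // The input offset is passed in bytes, while `dims`/`stride` are in
        // elements; `axis`, `reverse` and `inclusive` mirror the CPU path.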
1452        crate::metal_kernels::call_scan(
1453            device.device(),
1454            &command_buffer,
1455            &crate::metal_kernels::Kernels::new(),
1456            s1.dtype(),
1457            ScanType::Sum,
1458            s1.buffer(),
1459            l1.start_offset() * s1.dtype().size_in_bytes(),
1460            self.axis,
1461            l1.dims(),
1462            l1.stride(),
1463            self.reverse,
1464            self.inclusive,
1465            &output,
1466        )
1467        .map_err(candle_core::Error::wrap)?;
1468
1469        let newstorage = candle_core::MetalStorage::new(
1470            output,
1471            device.clone(),
1472            out_shape.elem_count(),
1473            s1.dtype(),
1474        );
1475        Ok((newstorage, out_shape))
1476    }
1477}
1478
1479#[allow(dead_code)]
1480pub trait CumSumOp {
1481    /// Exclusive forward scan along `axis`; shorthand for `fast_cumsum_config(axis, false, false)`.
1482    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1483
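    /// Scan along `axis` with explicit behaviour: `inclusive` keeps each
    /// element in its own running sum, `reverse` scans from the end of the
    /// axis towards the start.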
1484    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1485        -> Result<Tensor>;
1486}
1487
1488impl CumSumOp for Tensor {
1489    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1490        self.fast_cumsum_config(axis, false, false)
1491    }
1492
1493    fn fast_cumsum_config<D: Dim>(
1494        &self,
1495        axis: D,
1496        inclusive: bool,
1497        reverse: bool,
1498    ) -> Result<Tensor> {
1499        self.apply_op1_no_bwd(&CumSum {
1500            inclusive,
1501            reverse,
1502            axis: axis.to_index(self.shape(), "cumsum")?,
1503        })
1504    }
1505}
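
// Usage sketch (matches the CPU tests below):
//     let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &candle_core::Device::Cpu)?;
//     a.fast_cumsum(0)?;                        // exclusive forward: [0, 1, 3, 6]
//     a.fast_cumsum_config(0, true, false)?;    // inclusive forward: [1, 3, 6, 10]
//     a.fast_cumsum_config(0, false, true)?;    // exclusive reverse: [9, 7, 4, 0]
//     a.fast_cumsum_config(0, true, true)?;     // inclusive reverse: [10, 9, 7, 4]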
1506
#[cfg(test)]
1507mod tests {
1508    #[test]
1509    fn test_cumsum_exclusive_forward_cpu() {
1510        use crate::utils::ops::CumSumOp;
1511        use candle_core::Tensor;
1512        let device = candle_core::Device::Cpu;
1513        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1514        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1515        assert_eq!(b, [0, 1, 3, 6]);
1516    }
1517
1518    #[test]
1519    fn test_cumsum_inclusive_forward_cpu() {
1520        use crate::utils::ops::CumSumOp;
1521        use candle_core::Tensor;
1522        let device = candle_core::Device::Cpu;
1523        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1524        let b = a
1525            .fast_cumsum_config(0, true, false)
1526            .unwrap()
1527            .to_vec1::<i64>()
1528            .unwrap();
1529        assert_eq!(b, [1, 3, 6, 10]);
1530    }
1531
1532    #[test]
1533    fn test_cumsum_exclusive_reverse_cpu() {
1534        use crate::utils::ops::CumSumOp;
1535        use candle_core::Tensor;
1536        let device = candle_core::Device::Cpu;
1537        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1538        let b = a
1539            .fast_cumsum_config(0, false, true)
1540            .unwrap()
1541            .to_vec1::<i64>()
1542            .unwrap();
1543        assert_eq!(b, [9, 7, 4, 0]);
1544    }
1545
1546    #[test]
1547    fn test_cumsum_inclusive_reverse_cpu() {
1548        use crate::utils::ops::CumSumOp;
1549        use candle_core::Tensor;
1550        let device = candle_core::Device::Cpu;
1551        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1552        let b = a
1553            .fast_cumsum_config(0, true, true)
1554            .unwrap()
1555            .to_vec1::<i64>()
1556            .unwrap();
1557        assert_eq!(b, [10, 9, 7, 4]);
1558    }
1559
1560    #[cfg(feature = "metal")]
1561    #[test]
1562    fn test_cumsum_exclusive_forward_metal() {
1563        use crate::utils::ops::CumSumOp;
1564        use candle_core::Tensor;
1565        let device = candle_core::Device::new_metal(0).unwrap();
1566        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1567        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1568        assert_eq!(b, [0, 1, 3, 6]);
1569    }
1570
1571    #[cfg(feature = "metal")]
1572    #[test]
1573    fn test_cumsum_inclusive_forward_metal() {
1574        use crate::utils::ops::CumSumOp;
1575        use candle_core::Tensor;
1576        let device = candle_core::Device::new_metal(0).unwrap();
1577        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1578        let b = a
1579            .fast_cumsum_config(0, true, false)
1580            .unwrap()
1581            .to_vec1::<i64>()
1582            .unwrap();
1583        assert_eq!(b, [1, 3, 6, 10]);
1584    }
1585
1586    #[cfg(feature = "metal")]
1587    #[test]
1588    fn test_cumsum_exclusive_reverse_metal() {
1589        use crate::utils::ops::CumSumOp;
1590        use candle_core::Tensor;
1591        let device = candle_core::Device::new_metal(0).unwrap();
1592        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1593        let b = a
1594            .fast_cumsum_config(0, false, true)
1595            .unwrap()
1596            .to_vec1::<i64>()
1597            .unwrap();
1598        assert_eq!(b, [9, 7, 4, 0]);
1599    }
1600
1601    #[cfg(feature = "metal")]
1602    #[test]
1603    fn test_cumsum_inclusive_reverse_metal() {
1604        use crate::utils::ops::CumSumOp;
1605        use candle_core::Tensor;
1606        let device = candle_core::Device::new_metal(0).unwrap();
1607        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1608        let b = a
1609            .fast_cumsum_config(0, true, true)
1610            .unwrap()
1611            .to_vec1::<i64>()
1612            .unwrap();
1613        assert_eq!(b, [10, 9, 7, 4]);
1614    }
1615
1616    #[test]
1617    fn test_nonzero_cpu() {
1618        use crate::utils::ops::NonZeroOp;
1619        use candle_core::Tensor;
1620        let device = candle_core::Device::Cpu;
1621        let a = Tensor::from_vec(
1622            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1623            &[2, 4],
1624            &device,
1625        )
1626        .unwrap();
1627        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1628        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1629    }
1630
1631    #[cfg(feature = "cuda")]
1632    #[test]
1633    fn test_nonzero_cuda() {
1634        use crate::utils::ops::NonZeroOp;
1635        use candle_core::Tensor;
1636        let device = candle_core::Device::new_cuda(0).unwrap();
1637        let a = Tensor::from_vec(
1638            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1639            &[2, 4],
1640            &device,
1641        )
1642        .unwrap();
1643        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1644        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1645    }
1646
1647    #[test]
1648    fn test_bitwise_and_cpu() {
1649        use crate::utils::ops::BitWiseOp;
1650        use candle_core::Tensor;
1651        let device = candle_core::Device::Cpu;
1652        let a =
1653            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1654        let b =
1655            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1656        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1657        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
1658    }
1659
1660    #[cfg(feature = "cuda")]
1661    #[test]
1662    fn test_bitwise_and_cuda() {
1663        use crate::utils::ops::BitWiseOp;
1664        use candle_core::Tensor;
1665        let device = candle_core::Device::new_cuda(0).unwrap();
1666        let a =
1667            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1668        let b =
1669            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
1670        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1671        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
1672    }
1673
1674    #[test]
1675    fn test_bitwise_or_cpu() {
1676        use crate::utils::ops::BitWiseOp;
1677        use candle_core::Tensor;
1678        let device = candle_core::Device::Cpu;
1679        let a =
1680            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1681        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1682        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1683        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1684    }
1685
1686    #[cfg(feature = "cuda")]
1687    #[test]
1688    fn test_bitwise_or_cuda() {
1689        use crate::utils::ops::BitWiseOp;
1690        use candle_core::Tensor;
1691        let device = candle_core::Device::new_cuda(0).unwrap();
1692        let a =
1693            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1694        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1695        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1696        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1697    }
1698
1699    #[test]
1700    fn test_bitwise_xor_cpu() {
1701        use crate::utils::ops::BitWiseOp;
1702        use candle_core::Tensor;
1703        let device = candle_core::Device::Cpu;
1704        let a =
1705            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1706        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1707        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1708        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1709    }
1710
1711    #[cfg(feature = "cuda")]
1712    #[test]
1713    fn test_bitwise_xor_cuda() {
1714        use crate::utils::ops::BitWiseOp;
1715        use candle_core::Tensor;
1716        let device = candle_core::Device::new_cuda(0).unwrap();
1717        let a =
1718            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1719        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1720        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1721        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1722    }
1723
1724    #[test]
1725    fn test_nonzero_and() {
1726        use crate::utils::ops::{BitWiseOp, NonZeroOp};
1727        use candle_core::{Device, Tensor};
1728
1729        let input1 = Tensor::from_vec(
1730            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
1731            (10,),
1732            &Device::Cpu,
1733        )
1734        .unwrap();
1735        let input2 = Tensor::from_vec(
1736            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
1737            (10,),
1738            &Device::Cpu,
1739        )
1740        .unwrap();
1741        let input = Tensor::stack(&[input1, input2], 0).unwrap();
1742
1743        let lt = input.lt(0.0).unwrap();
1744        let gt = input.gt(-10.0).unwrap();
1745        let res = lt
1746            .bitwise_and(&gt)
1747            .unwrap()
1748            .nonzero()
1749            .unwrap()
1750            .to_vec2::<u32>()
1751            .unwrap();
1752
1753        assert_eq!(
1754            res,
1755            [
1756                [0, 3],
1757                [0, 4],
1758                [0, 5],
1759                [0, 6],
1760                [1, 0],
1761                [1, 3],
1762                [1, 5],
1763                [1, 6]
1764            ]
1765        );
1766    }
1767
1768    #[cfg(feature = "cuda")]
1769    #[test]
1770    fn nonzero_and_cuda() {
1771        use crate::utils::ops::{BitWiseOp, NonZeroOp};
1772        use candle_core::{Device, Tensor};
1773
1774        let device = Device::new_cuda(0).unwrap();
1775        let input1 =
1776            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1777        let input2 =
1778            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1779        let input = Tensor::stack(&[input1, input2], 0).unwrap();
1780
1781        let lt = input.lt(0.0).unwrap();
1782        let gt = input.gt(-10.0).unwrap();
1783        let res = lt
1784            .bitwise_and(&gt)
1785            .unwrap()
1786            .nonzero()
1787            .unwrap()
1788            .to_vec2::<u32>()
1789            .unwrap();
1790
1791        assert_eq!(
1792            res,
1793            [
1794                [0, 3],
1795                [0, 4],
1796                [0, 5],
1797                [0, 6],
1798                [1, 0],
1799                [1, 3],
1800                [1, 5],
1801                [1, 6]
1802            ]
1803        );
1804    }
1805
1806    #[test]
1807    fn test_bitpack_8bit_cpu() {
1808        use crate::HqqBits;
1809        use candle_core::{Device, Tensor};
1810        let bits = HqqBits::Eight;
1811        let device = Device::Cpu;
1812        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1813        let c = bits.bitpack_type()(wq.clone())
1814            .unwrap()
1815            .to_vec2::<u8>()
1816            .unwrap();
1817        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1818    }
1819
1820    #[cfg(feature = "cuda")]
1821    #[test]
1822    fn test_bitpack_8bit_cuda() {
1823        use crate::HqqBits;
1824        use candle_core::DType;
1825        use candle_core::{Device, Tensor};
1826        let bits = HqqBits::Eight;
1827        let device = Device::new_cuda(0).unwrap();
1828        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1829        let c = bits.bitpack_type()(wq.clone())
1830            .unwrap()
1831            .to_dtype(DType::U8)
1832            .unwrap()
1833            .to_vec2::<u8>()
1834            .unwrap();
1835        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1836    }
1837
1838    #[cfg(feature = "metal")]
1839    #[test]
1840    fn test_bitpack_8bit_metal() {
1841        use crate::HqqBits;
1842        use candle_core::{Device, Tensor};
1843        let bits = HqqBits::Eight;
1844        let device = Device::new_metal(0).unwrap();
1845        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1846        let c = bits.bitpack_type()(wq.clone())
1847            .unwrap()
1848            .to_vec2::<u8>()
1849            .unwrap();
1850        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1851    }
1852
1853    #[test]
1854    fn test_bitpack_4bit() {
1855        use crate::HqqBits;
1856        use candle_core::{Device, Tensor};
1857        let bits = HqqBits::Four;
1858        let device = Device::Cpu;
1859        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1860        let c = bits.bitpack_type()(wq.clone())
1861            .unwrap()
1862            .to_vec2::<u8>()
1863            .unwrap();
1864        assert_eq!(c, [[19, 36]]);
1865    }
1866
1867    #[cfg(feature = "cuda")]
1868    #[test]
1869    fn test_bitpack_4bit_cuda() {
1870        use crate::HqqBits;
1871        use candle_core::{Device, Tensor};
1872        let bits = HqqBits::Four;
1873        let device = Device::new_cuda(0).unwrap();
1874        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1875        let c = bits.bitpack_type()(wq.clone())
1876            .unwrap()
1877            .to_vec2::<u8>()
1878            .unwrap();
1879        assert_eq!(c, [[19, 36]]);
1880    }
1881
1882    #[cfg(feature = "metal")]
1883    #[test]
1884    fn test_bitpack_4bit_metal() {
1885        use crate::HqqBits;
1886        use candle_core::{Device, Tensor};
1887        let bits = HqqBits::Four;
1888        let device = Device::new_metal(0).unwrap();
1889        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1890        let c = bits.bitpack_type()(wq.clone())
1891            .unwrap()
1892            .to_vec2::<u8>()
1893            .unwrap();
1894        assert_eq!(c, [[19, 36]]);
1895    }
1896    // ─────────────────────────────── Sort / ArgSort ────────────────────────────────
1897    #[cfg(feature = "metal")]
1898    #[test]
1899    fn test_sort_and_argsort_vector_metal() {
1900        use crate::utils::ops::SortOp;
1901        use candle_core::Tensor;
1902
1903        let device = candle_core::Device::new_metal(0).unwrap();
1904        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
1905
1906        // sort (ascending)
1907        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1908        assert_eq!(sorted, [1, 2, 3, 4]);
1909
1910        // argsort (ascending indices)
1911        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1912        assert_eq!(idx, [1, 3, 0, 2]);
1913    }
1914
1915    #[cfg(feature = "metal")]
1916    #[test]
1917    fn test_sort_and_argsort_matrix_axis1_metal() {
1918        use crate::utils::ops::SortOp;
1919        use candle_core::Tensor;
1920
1921        let device = candle_core::Device::new_metal(0).unwrap();
1922        // 2 × 3 matrix:
1923        // [[3, 1, 2],
1924        //  [0, 4, 5]]
1925        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
1926
1927        // Sort along axis=1 (second dimension)
1928        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
1929        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
1930
1931        // ArgSort indices along axis=1
1932        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
1933        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
1934    }
1935
1936    // ─────────────────────────────── 4096-element vector ────────────────────────────────
1937    #[cfg(feature = "metal")]
1938    #[test]
1939    fn test_sort_and_argsort_vector_4096_metal() {
1940        use crate::utils::ops::SortOp;
1941        use candle_core::Tensor;
1942
1943        const N: usize = 4096;
1944
1945        let device = candle_core::Device::new_metal(0).expect("Metal device");
1946
1947        // Create a descending vector [4095, 4094, …, 0]
1948        let vals: Vec<i32> = (0..N as i32).rev().collect();
1949        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
1950
1951        // ---- sort (ascending) ---------------------------------------------------------
1952        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1953        let expected: Vec<i32> = (0..N as i32).collect();
1954        assert_eq!(sorted, expected);
1955
1956        // ---- argsort (indices that would sort) ---------------------------------------
1957        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1958        // Because the input is reversed, the correct indices are likewise reversed
1959        for (i, &v) in idx.iter().enumerate() {
1960            assert_eq!(v as usize, N - 1 - i);
1961        }
1962    }
1963}