mistralrs_quant/utils/ops.rs

use candle_core::{
    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
    Result, Shape, Tensor, WithDType,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use std::{
    fmt::Display,
    ops::{BitAnd, BitOr, BitXor, Not, Shl},
};

#[cfg(feature = "cuda")]
use crate::utils::{ffi, slice_ptr};
#[cfg(feature = "cuda")]
use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage};
#[cfg(feature = "cuda")]
use std::ffi::c_void;

#[cfg(feature = "metal")]
use crate::metal_kernels::SortScratchCache; // scratch-buffer cache shared by the sort kernels below
#[cfg(feature = "metal")]
use std::sync::OnceLock;

#[cfg(feature = "metal")]
static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
27struct Leftshift(usize);
28
29impl Leftshift {
30    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31        let offset = T::from_f64(self.0 as f64);
32        vs.into_par_iter().map(|v| *v << offset).collect()
33    }
34}
35
36impl CustomOp1 for Leftshift {
37    fn name(&self) -> &'static str {
        "leftshift"
39    }
40
41    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42        match s1 {
43            CpuStorage::U8(vs1) => {
44                let vs1 = match l1.contiguous_offsets() {
45                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
47                };
48                let result = self.leftshift(vs1);
49                let result = CpuStorage::U8(result);
50                Ok((result, l1.shape().clone()))
51            }
52            CpuStorage::I16(vs1) => {
53                let vs1 = match l1.contiguous_offsets() {
54                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
56                };
57                let result = self.leftshift(vs1);
58                let result = CpuStorage::I16(result);
59                Ok((result, l1.shape().clone()))
60            }
61            CpuStorage::U32(vs1) => {
62                let vs1 = match l1.contiguous_offsets() {
63                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
65                };
66                let result = self.leftshift(vs1);
67                let result = CpuStorage::U32(result);
68                Ok((result, l1.shape().clone()))
69            }
70            CpuStorage::I64(vs1) => {
71                let vs1 = match l1.contiguous_offsets() {
72                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
74                };
75                let result = self.leftshift(vs1);
76                let result = CpuStorage::I64(result);
77                Ok((result, l1.shape().clone()))
78            }
79            CpuStorage::I32(vs1) => {
80                let vs1 = match l1.contiguous_offsets() {
81                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
83                };
84                let result = self.leftshift(vs1);
85                let result = CpuStorage::I32(result);
86                Ok((result, l1.shape().clone()))
87            }
88            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
89        }
90    }
91
92    #[cfg(feature = "cuda")]
93    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
94        if !l1.is_contiguous() {
95            candle_core::bail!("Input tensor s1 must be contiguous");
96        }
97        let dev = s1.device().clone();
98        let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
99            DType::U8 => {
100                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
101                let elem_count = l1.shape().elem_count();
102                (d_in1 as *const c_void, d_in1_guard, elem_count)
103            }
104            DType::I32 => {
105                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
106                let elem_count = l1.shape().elem_count();
107                (d_in1 as *const c_void, d_in1_guard, elem_count)
108            }
109            other => {
110                return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
111            }
112        };
113        let dst = match s1.dtype() {
114            DType::U8 => {
115                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
116                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
117                unsafe {
118                    ffi::leftshift_u8(
119                        d_in1_ptr,
120                        d_out_ptr as *mut std::ffi::c_void,
121                        u32::try_from(elem_count)?,
122                        self.0 as i32,
123                    )
124                };
125                drop(d_out_guard);
126                CudaStorage::wrap_cuda_slice(d_out, dev)
127            }
128            DType::I32 => {
129                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
130                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
131                unsafe {
132                    ffi::leftshift_i32(
133                        d_in1_ptr,
134                        d_out_ptr as *mut std::ffi::c_void,
135                        u32::try_from(elem_count)?,
136                        self.0 as i32,
137                    )
138                };
139                drop(d_out_guard);
140                CudaStorage::wrap_cuda_slice(d_out, dev)
141            }
142            _ => unreachable!(),
143        };
144        Ok((dst, l1.shape().clone()))
145    }
146
147    #[cfg(feature = "metal")]
148    fn metal_fwd(
149        &self,
150        s1: &candle_core::MetalStorage,
151        l1: &Layout,
152    ) -> Result<(candle_core::MetalStorage, Shape)> {
153        if !l1.is_contiguous() {
154            candle_core::bail!("Input tensor s1 must be contiguous");
155        }
156
157        let command_buffer = s1.device().command_buffer()?;
158        command_buffer.set_label("bitwise-leftshift");
159
160        let device = s1.device();
161
162        let out_shape = l1.shape().clone();
163
164        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
165
166        crate::metal_kernels::call_bitwise_leftshift(
167            device.device(),
168            &command_buffer,
169            &crate::metal_kernels::Kernels::new(),
170            s1.dtype(),
171            s1.buffer(),
172            l1.start_offset(),
173            self.0 as u32,
174            out_shape.elem_count(),
175            &output,
176        )
177        .map_err(candle_core::Error::wrap)?;
178
179        let newstorage = candle_core::MetalStorage::new(
180            output,
181            device.clone(),
182            out_shape.elem_count(),
183            s1.dtype(),
184        );
185        Ok((newstorage, out_shape))
186    }
187}
188
189#[allow(dead_code)]
190pub trait LeftshiftOp {
191    fn leftshift(&self, n: usize) -> Result<Tensor>;
192}
193
194impl LeftshiftOp for Tensor {
195    fn leftshift(&self, n: usize) -> Result<Tensor> {
196        self.apply_op1_no_bwd(&Leftshift(n))
197    }
198}
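
// Minimal usage sketch for `LeftshiftOp` (a hedged example, not part of the op
// implementation itself): every element of an integer tensor is shifted left by
// `n` bits. Assumes a CPU device and candle's `Tensor::new` / `to_vec1` helpers.
#[cfg(test)]
mod leftshift_example {
    use super::LeftshiftOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn leftshift_by_two() -> Result<()> {
        let t = Tensor::new(&[1u32, 2, 4], &Device::Cpu)?;
        // 1 << 2 = 4, 2 << 2 = 8, 4 << 2 = 16
        assert_eq!(t.leftshift(2)?.to_vec1::<u32>()?, vec![4, 8, 16]);
        Ok(())
    }
}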
199
200pub enum BitWiseBinaryOpEnum {
201    And,
202    Or,
203    Xor,
204}
205
206impl Display for BitWiseBinaryOpEnum {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        match self {
209            BitWiseBinaryOpEnum::And => write!(f, "And"),
210            BitWiseBinaryOpEnum::Or => write!(f, "Or"),
211            BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
212        }
213    }
214}
215
216pub enum BitWiseUnaryOpEnum {
217    Not,
218}
219
220impl Display for BitWiseUnaryOpEnum {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            BitWiseUnaryOpEnum::Not => write!(f, "Not"),
224        }
225    }
226}
227
228struct BitWise {
229    pub op: BitWiseBinaryOpEnum,
230}
231
232impl BitWise {
233    pub fn new(op: BitWiseBinaryOpEnum) -> Self {
234        Self { op }
235    }
236
237    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
238        &self,
239        vs1: &[T],
240        vs2: &[T],
241    ) -> Vec<T> {
242        vs1.into_par_iter()
243            .zip_eq(vs2)
244            .map(|(v1, v2)| match self.op {
245                BitWiseBinaryOpEnum::And => *v1 & *v2,
246                BitWiseBinaryOpEnum::Or => *v1 | *v2,
247                BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
248            })
249            .collect()
250    }
251}
252
253impl CustomOp2 for BitWise {
254    fn name(&self) -> &'static str {
255        "bitwise"
256    }
257
258    fn cpu_fwd(
259        &self,
260        s1: &CpuStorage,
261        l1: &Layout,
262        s2: &CpuStorage,
263        l2: &Layout,
264    ) -> Result<(CpuStorage, Shape)> {
265        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
266            return Err(Error::ShapeMismatchBinaryOp {
267                lhs: l1.shape().clone(),
268                rhs: l2.shape().clone(),
269                op: "bitwise-op",
270            });
271        }
272        if s1.dtype() != s2.dtype() {
273            return Err(Error::DTypeMismatchBinaryOp {
274                lhs: s1.dtype(),
275                rhs: s2.dtype(),
276                op: "bitwise-op",
277            });
278        }
279        if !l1.is_contiguous() {
280            candle_core::bail!("Input tensor s1 must be contiguous");
281        }
282        if !l2.is_contiguous() {
283            candle_core::bail!("Input tensor s2 must be contiguous");
284        }
285
286        match s1 {
287            CpuStorage::U8(vs1) => {
288                let vs2 = s2.as_slice::<u8>().unwrap();
289                let vs1 = match l1.contiguous_offsets() {
290                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
292                };
293                let vs2 = match l2.contiguous_offsets() {
294                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
296                };
297                let result = self.bitwise(vs1, vs2);
298                let result = CpuStorage::U8(result);
299                Ok((result, l1.shape().clone()))
300            }
301            CpuStorage::U32(vs1) => {
302                let vs2 = s2.as_slice::<u32>().unwrap();
303                let vs1 = match l1.contiguous_offsets() {
304                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
306                };
307                let vs2 = match l2.contiguous_offsets() {
308                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
310                };
311                let result = self.bitwise(vs1, vs2);
312                let result = CpuStorage::U32(result);
313                Ok((result, l1.shape().clone()))
314            }
315            CpuStorage::I64(vs1) => {
316                let vs2 = s2.as_slice::<i64>().unwrap();
317                let vs1 = match l1.contiguous_offsets() {
318                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
320                };
321                let vs2 = match l2.contiguous_offsets() {
322                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
324                };
325                let result = self.bitwise(vs1, vs2);
326                let result = CpuStorage::I64(result);
327                Ok((result, l1.shape().clone()))
328            }
329            CpuStorage::I16(vs1) => {
330                let vs2 = s2.as_slice::<i16>().unwrap();
331                let vs1 = match l1.contiguous_offsets() {
332                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
334                };
335                let vs2 = match l2.contiguous_offsets() {
336                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
338                };
339                let result = self.bitwise(vs1, vs2);
340                let result = CpuStorage::I16(result);
341                Ok((result, l1.shape().clone()))
342            }
343            CpuStorage::I32(vs1) => {
344                let vs2 = s2.as_slice::<i32>().unwrap();
345                let vs1 = match l1.contiguous_offsets() {
346                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
348                };
349                let vs2 = match l2.contiguous_offsets() {
350                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
352                };
353                let result = self.bitwise(vs1, vs2);
354                let result = CpuStorage::I32(result);
355                Ok((result, l1.shape().clone()))
356            }
357            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
358        }
359    }
360
361    #[cfg(feature = "cuda")]
362    fn cuda_fwd(
363        &self,
364        s1: &CudaStorage,
365        l1: &Layout,
366        s2: &CudaStorage,
367        l2: &Layout,
368    ) -> Result<(CudaStorage, Shape)> {
369        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
370            return Err(Error::ShapeMismatchBinaryOp {
371                lhs: l1.shape().clone(),
372                rhs: l2.shape().clone(),
373                op: "bitwise-op",
374            });
375        }
376        if s1.dtype() != s2.dtype() {
377            return Err(Error::DTypeMismatchBinaryOp {
378                lhs: s1.dtype(),
379                rhs: s2.dtype(),
380                op: "bitwise-op",
381            });
382        }
383        if !l1.is_contiguous() {
384            candle_core::bail!("Input tensor s1 must be contiguous");
385        }
386        if !l2.is_contiguous() {
387            candle_core::bail!("Input tensor s2 must be contiguous");
388        }
389
390        let dev = s1.device().clone();
391        let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
392            DType::U8 => {
393                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
394                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
395                let elem_count = l1.shape().elem_count();
396                (
397                    d_in1 as *const std::ffi::c_void,
398                    d_in2 as *const std::ffi::c_void,
399                    d_in1_guard,
400                    d_in2_guard,
401                    elem_count,
402                )
403            }
404            DType::U32 => {
405                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
406                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
407                let elem_count = l1.shape().elem_count();
408                (
409                    d_in1 as *const std::ffi::c_void,
410                    d_in2 as *const std::ffi::c_void,
411                    d_in1_guard,
412                    d_in2_guard,
413                    elem_count,
414                )
415            }
416            DType::I64 => {
417                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
418                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
419                let elem_count = l1.shape().elem_count();
420                (
421                    d_in1 as *const std::ffi::c_void,
422                    d_in2 as *const std::ffi::c_void,
423                    d_in1_guard,
424                    d_in2_guard,
425                    elem_count,
426                )
427            }
428            DType::I32 => {
429                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
430                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
431                let elem_count = l1.shape().elem_count();
432                (
433                    d_in1 as *const std::ffi::c_void,
434                    d_in2 as *const std::ffi::c_void,
435                    d_in1_guard,
436                    d_in2_guard,
437                    elem_count,
438                )
439            }
440            DType::I16 => {
441                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
442                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
443                let elem_count = l1.shape().elem_count();
444                (
445                    d_in1 as *const std::ffi::c_void,
446                    d_in2 as *const std::ffi::c_void,
447                    d_in1_guard,
448                    d_in2_guard,
449                    elem_count,
450                )
451            }
452            other => {
453                return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
454            }
455        };
456        let dst = match s1.dtype() {
457            DType::U8 => {
458                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
459                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
460                unsafe {
461                    match self.op {
462                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
463                            d_in1_ptr,
464                            d_in2_ptr,
465                            d_out_ptr as *mut c_void,
466                            u32::try_from(elem_count)?,
467                        ),
468                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
469                            d_in1_ptr,
470                            d_in2_ptr,
471                            d_out_ptr as *mut c_void,
472                            u32::try_from(elem_count)?,
473                        ),
474                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
475                            d_in1_ptr,
476                            d_in2_ptr,
477                            d_out_ptr as *mut c_void,
478                            u32::try_from(elem_count)?,
479                        ),
480                    }
481                };
482                drop(d_out_guard);
483                CudaStorage::wrap_cuda_slice(d_out, dev)
484            }
485            DType::U32 => {
486                let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
487                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
488                unsafe {
489                    match self.op {
490                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
491                            d_in1_ptr,
492                            d_in2_ptr,
493                            d_out_ptr as *mut c_void,
494                            u32::try_from(elem_count)?,
495                        ),
496                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
497                            d_in1_ptr,
498                            d_in2_ptr,
499                            d_out_ptr as *mut c_void,
500                            u32::try_from(elem_count)?,
501                        ),
502                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
503                            d_in1_ptr,
504                            d_in2_ptr,
505                            d_out_ptr as *mut c_void,
506                            u32::try_from(elem_count)?,
507                        ),
508                    }
509                };
510                drop(d_out_guard);
511                CudaStorage::wrap_cuda_slice(d_out, dev)
512            }
513            DType::I64 => {
514                let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
515                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
516                unsafe {
517                    match self.op {
518                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
519                            d_in1_ptr,
520                            d_in2_ptr,
521                            d_out_ptr as *mut c_void,
522                            u32::try_from(elem_count)?,
523                        ),
524                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
525                            d_in1_ptr,
526                            d_in2_ptr,
527                            d_out_ptr as *mut c_void,
528                            u32::try_from(elem_count)?,
529                        ),
530                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
531                            d_in1_ptr,
532                            d_in2_ptr,
533                            d_out_ptr as *mut c_void,
534                            u32::try_from(elem_count)?,
535                        ),
536                    }
537                };
538                drop(d_out_guard);
539                CudaStorage::wrap_cuda_slice(d_out, dev)
540            }
541            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
543                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
544                unsafe {
545                    match self.op {
546                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
547                            d_in1_ptr,
548                            d_in2_ptr,
549                            d_out_ptr as *mut c_void,
550                            u32::try_from(elem_count)?,
551                        ),
552                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
553                            d_in1_ptr,
554                            d_in2_ptr,
555                            d_out_ptr as *mut c_void,
556                            u32::try_from(elem_count)?,
557                        ),
558                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
559                            d_in1_ptr,
560                            d_in2_ptr,
561                            d_out_ptr as *mut c_void,
562                            u32::try_from(elem_count)?,
563                        ),
564                    }
565                };
566                drop(d_out_guard);
567                CudaStorage::wrap_cuda_slice(d_out, dev)
568            }
569            _ => unreachable!(),
570        };
571        Ok((dst, l1.shape().clone()))
572    }
573
574    #[cfg(feature = "metal")]
575    fn metal_fwd(
576        &self,
577        s1: &candle_core::MetalStorage,
578        l1: &Layout,
579        s2: &candle_core::MetalStorage,
580        l2: &Layout,
581    ) -> Result<(candle_core::MetalStorage, Shape)> {
582        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
583            return Err(Error::ShapeMismatchBinaryOp {
584                lhs: l1.shape().clone(),
585                rhs: l2.shape().clone(),
586                op: "bitwise-op",
587            });
588        }
589        if s1.dtype() != s2.dtype() {
590            return Err(Error::DTypeMismatchBinaryOp {
591                lhs: s1.dtype(),
592                rhs: s2.dtype(),
593                op: "bitwise-op",
594            });
595        }
596        if !l1.is_contiguous() {
597            candle_core::bail!("Input tensor s1 must be contiguous");
598        }
599        if !l2.is_contiguous() {
600            candle_core::bail!("Input tensor s2 must be contiguous");
601        }
602
603        let command_buffer = s1.device().command_buffer()?;
604        command_buffer.set_label("bitwise-op");
605
606        let device = s1.device();
607
608        let out_shape = l1.shape().clone();
609
610        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
611
612        match self.op {
613            BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
614                device.device(),
615                &command_buffer,
616                &crate::metal_kernels::Kernels::new(),
617                s1.dtype(),
618                s1.buffer(),
619                s2.buffer(),
620                l1.start_offset() * s1.dtype().size_in_bytes(),
621                l2.start_offset() * s2.dtype().size_in_bytes(),
622                out_shape.elem_count(),
623                &output,
624            )
625            .map_err(candle_core::Error::wrap)?,
626            BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
627                device.device(),
628                &command_buffer,
629                &crate::metal_kernels::Kernels::new(),
630                s1.dtype(),
631                s1.buffer(),
632                s2.buffer(),
633                l1.start_offset() * s1.dtype().size_in_bytes(),
634                l2.start_offset() * s2.dtype().size_in_bytes(),
635                out_shape.elem_count(),
636                &output,
637            )
638            .map_err(candle_core::Error::wrap)?,
639            BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
640                device.device(),
641                &command_buffer,
642                &crate::metal_kernels::Kernels::new(),
643                s1.dtype(),
644                s1.buffer(),
645                s2.buffer(),
646                l1.start_offset() * s1.dtype().size_in_bytes(),
647                l2.start_offset() * s2.dtype().size_in_bytes(),
648                out_shape.elem_count(),
649                &output,
650            )
651            .map_err(candle_core::Error::wrap)?,
652        }
653
654        let newstorage = candle_core::MetalStorage::new(
655            output,
656            device.clone(),
657            out_shape.elem_count(),
658            s1.dtype(),
659        );
660        Ok((newstorage, out_shape))
661    }
662}
663
664struct BitWiseUnary {
665    pub op: BitWiseUnaryOpEnum,
666}
667
668impl BitWiseUnary {
669    pub fn new(op: BitWiseUnaryOpEnum) -> Self {
670        Self { op }
671    }
672
673    fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
674        vs1.into_par_iter()
675            .map(|v1| match self.op {
676                BitWiseUnaryOpEnum::Not => !*v1,
677            })
678            .collect()
679    }
680}
681
682impl CustomOp1 for BitWiseUnary {
683    fn name(&self) -> &'static str {
684        "bitwise-unary"
685    }
686
687    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
688        if !l1.is_contiguous() {
689            candle_core::bail!("Input tensor s1 must be contiguous");
690        }
691
692        match s1 {
693            CpuStorage::U8(vs1) => {
694                let vs1 = match l1.contiguous_offsets() {
695                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
697                };
698                let result = self.bitwise(vs1);
699                let result = CpuStorage::U8(result);
700                Ok((result, l1.shape().clone()))
701            }
702            CpuStorage::U32(vs1) => {
703                let vs1 = match l1.contiguous_offsets() {
704                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
706                };
707                let result = self.bitwise(vs1);
708                let result = CpuStorage::U32(result);
709                Ok((result, l1.shape().clone()))
710            }
711            CpuStorage::I64(vs1) => {
712                let vs1 = match l1.contiguous_offsets() {
713                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
715                };
716                let result = self.bitwise(vs1);
717                let result = CpuStorage::I64(result);
718                Ok((result, l1.shape().clone()))
719            }
720            CpuStorage::I16(vs1) => {
721                let vs1 = match l1.contiguous_offsets() {
722                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
724                };
725                let result = self.bitwise(vs1);
726                let result = CpuStorage::I16(result);
727                Ok((result, l1.shape().clone()))
728            }
729            CpuStorage::I32(vs1) => {
730                let vs1 = match l1.contiguous_offsets() {
731                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
733                };
734                let result = self.bitwise(vs1);
735                let result = CpuStorage::I32(result);
736                Ok((result, l1.shape().clone()))
737            }
738            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
739        }
740    }
741
742    #[cfg(feature = "cuda")]
743    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        candle_core::bail!("bitwise-not is not implemented for the CUDA backend");
745    }
746
747    #[cfg(feature = "metal")]
748    fn metal_fwd(
749        &self,
750        s1: &candle_core::MetalStorage,
751        l1: &Layout,
752    ) -> Result<(candle_core::MetalStorage, Shape)> {
753        if !l1.is_contiguous() {
754            candle_core::bail!("Input tensor s1 must be contiguous");
755        }
756
757        let command_buffer = s1.device().command_buffer()?;
758        command_buffer.set_label("bitwise-unary-op");
759
760        let device = s1.device();
761
762        let out_shape = l1.shape().clone();
763
764        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
765
766        match self.op {
767            BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
768                device.device(),
769                &command_buffer,
770                &crate::metal_kernels::Kernels::new(),
771                s1.dtype(),
772                s1.buffer(),
773                l1.start_offset() * s1.dtype().size_in_bytes(),
774                out_shape.elem_count(),
775                &output,
776            )
777            .map_err(candle_core::Error::wrap)?,
778        }
779
780        let newstorage = candle_core::MetalStorage::new(
781            output,
782            device.clone(),
783            out_shape.elem_count(),
784            s1.dtype(),
785        );
786        Ok((newstorage, out_shape))
787    }
788}
789
790#[allow(dead_code)]
791pub trait BitWiseOp {
792    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
793    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
794    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
795    fn bitwise_not(&self) -> Result<Tensor>;
796}
797
798impl BitWiseOp for Tensor {
799    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
800        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
801    }
802
803    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
804        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
805    }
806
807    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
808        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
809    }
810
811    fn bitwise_not(&self) -> Result<Tensor> {
812        self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
813    }
814}
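
// Minimal usage sketch for `BitWiseOp` (hedged example): element-wise AND / OR /
// XOR over two same-shape integer tensors, plus unary NOT. Assumes a CPU device.
#[cfg(test)]
mod bitwise_example {
    use super::BitWiseOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn bitwise_ops_on_u8() -> Result<()> {
        let a = Tensor::new(&[0b1100u8, 0b1010], &Device::Cpu)?;
        let b = Tensor::new(&[0b1010u8, 0b0110], &Device::Cpu)?;
        assert_eq!(a.bitwise_and(&b)?.to_vec1::<u8>()?, vec![0b1000, 0b0010]);
        assert_eq!(a.bitwise_or(&b)?.to_vec1::<u8>()?, vec![0b1110, 0b1110]);
        assert_eq!(a.bitwise_xor(&b)?.to_vec1::<u8>()?, vec![0b0110, 0b1100]);
        // NOT on u8 flips all 8 bits: !0b1100 == 0b1111_0011 == 243
        assert_eq!(a.bitwise_not()?.to_vec1::<u8>()?, vec![243, 245]);
        Ok(())
    }
}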
815
816// ────────────────────────────── ArgSort / Sort ────────────────────────────────
817
818#[allow(unused)]
819/// Configuration for an **argsort** (returns indices) operation.
820struct ArgSort {
821    axis: usize,
822}
823
824#[allow(unused)]
825/// Configuration for a **sort** (returns re‑ordered values) operation.
826struct Sort {
827    axis: usize,
828}
829
830impl CustomOp1 for ArgSort {
831    fn name(&self) -> &'static str {
832        "argsort"
833    }
834
835    // -------- CPU ------------------------------------------------------------
836    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
837        candle_core::bail!("ArgSort is not implemented for the CPU backend");
838    }
839
840    // -------- CUDA -----------------------------------------------------------
841    #[cfg(feature = "cuda")]
842    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
843        candle_core::bail!("ArgSort is not implemented for the CUDA backend");
844    }
845
846    // -------- Metal ----------------------------------------------------------
847    #[cfg(feature = "metal")]
848    fn metal_fwd(
849        &self,
850        s1: &candle_core::MetalStorage,
851        l1: &Layout,
852    ) -> Result<(candle_core::MetalStorage, Shape)> {
853        // Require contiguous input (same as other metal ops in this file)
854        if !l1.is_contiguous() {
855            candle_core::bail!("Input tensor s1 must be contiguous");
856        }
857
858        // Create a command‑buffer and label it for easy debugging in Xcode’s GPU frame‑capture
859        let command_buffer = s1.device().command_buffer()?;
860        command_buffer.set_label("argsort");
861
862        let device = s1.device();
863        let out_shape = l1.shape().clone();
864        let elem_count = out_shape.elem_count();
865
866        // Output buffer holds the sorted indices → always `U32`
867        let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
868
869        // ------------------------------------------------------------------
870        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
871        // ------------------------------------------------------------------
872        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
873
874        let dims = l1.dims();
875        let size_sorted_axis = dims[self.axis];
876        let n_rows = l1.shape().elem_count() / size_sorted_axis;
877
878        // Replicate the kernel’s internal block sizing to derive `n_blocks`
879        let tn = 4usize;
880        let mut bn = match size_sorted_axis.div_ceil(tn) {
881            v if v > 256 => 512,
882            v if v > 128 => 256,
883            v if v > 64 => 128,
884            v if v > 32 => 64,
885            _ => 32,
886        };
887        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
888            bn = 256;
889        }
890        let n_per_block = bn * tn;
891        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
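        // Worked example of the sizing above (illustrative numbers): with
        // size_sorted_axis = 1000 and tn = 4, div_ceil gives 250, so bn = 256
        // (the clamp for dtypes wider than 4 bytes only applies when bn would
        // be 512). Then n_per_block = 256 * 4 = 1024 and
        // n_blocks = ceil(1000 / 1024) = 1, i.e. one block covers the row.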
892
893        // Borrow the buffers for this launch
894        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
895
896        // ------------------------------------------------------------------
897        // Build the unified SortArgs payload
898        // ------------------------------------------------------------------
899        let sort_args = crate::metal_kernels::SortArgs {
900            axis: self.axis,
901            shape: l1.dims(),
902            strides: l1.stride(),
903            out_shape: l1.dims(), // same as input for argsort
904            out_strides: l1.stride(),
905            in_contiguous: l1.is_contiguous(),
906            in_ty: s1.dtype(),
907            out_ty: candle_core::DType::U32,
908            src: s1.buffer(),
909            src_offset: l1.start_offset(), // element offset
910            dst: &output,
911            bn,
912            tn,
913            n_blocks,
914        };
915
916        // Launch the Metal kernel via the new API
917        crate::metal_kernels::call_argsort(
918            device.device(), // &metal::Device
919            &command_buffer, // impl EncoderProvider
920            &crate::metal_kernels::Kernels::new(),
921            &sort_args,
922            &scratch,
923        )
924        .map_err(candle_core::Error::wrap)?;
925
926        // Wrap and return as a new MetalStorage
927        let newstorage = candle_core::MetalStorage::new(
928            output,
929            device.clone(),
930            elem_count,
931            candle_core::DType::U32,
932        );
933        Ok((newstorage, out_shape))
934    }
935}
936
937impl CustomOp1 for Sort {
938    fn name(&self) -> &'static str {
939        "sort"
940    }
941
942    // -------- CPU ------------------------------------------------------------
943    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
944        candle_core::bail!("Sort is not implemented for the CPU backend");
945    }
946
947    // -------- CUDA -----------------------------------------------------------
948    #[cfg(feature = "cuda")]
949    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
950        candle_core::bail!("Sort is not implemented for the CUDA backend");
951    }
952
953    // -------- Metal ----------------------------------------------------------
954    #[cfg(feature = "metal")]
955    fn metal_fwd(
956        &self,
957        s1: &candle_core::MetalStorage,
958        l1: &Layout,
959    ) -> Result<(candle_core::MetalStorage, Shape)> {
960        // Require contiguous input (same as other metal ops in this file)
961        if !l1.is_contiguous() {
962            candle_core::bail!("Input tensor s1 must be contiguous");
963        }
964
965        // Create a command‑buffer and label it for easy debugging in Xcode’s GPU frame‑capture
966        let command_buffer = s1.device().command_buffer()?;
967        command_buffer.set_label("sort");
968
969        let device = s1.device();
970        let out_shape = l1.shape().clone();
971        let elem_count = out_shape.elem_count();
972
973        // Output buffer keeps the same dtype as the input (these are the reordered values)
974        let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
975
976        // ------------------------------------------------------------------
977        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
978        // ------------------------------------------------------------------
979        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
980
981        let dims = l1.dims();
982        let size_sorted_axis = dims[self.axis];
983        let n_rows = l1.shape().elem_count() / size_sorted_axis;
984
985        // Replicate the kernel’s internal block sizing to derive `n_blocks`
986        let tn = 4usize;
987        let mut bn = match size_sorted_axis.div_ceil(tn) {
988            v if v > 256 => 512,
989            v if v > 128 => 256,
990            v if v > 64 => 128,
991            v if v > 32 => 64,
992            _ => 32,
993        };
994        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
995            bn = 256;
996        }
997        let n_per_block = bn * tn;
998        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
999
1000        // Borrow the buffers for this launch
1001        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1002
1003        // ------------------------------------------------------------------
1004        // Build the unified SortArgs payload
1005        // ------------------------------------------------------------------
1006        let sort_args = crate::metal_kernels::SortArgs {
1007            axis: self.axis,
1008            shape: l1.dims(),
1009            strides: l1.stride(),
1010            out_shape: l1.dims(), // same shape for value sort
1011            out_strides: l1.stride(),
1012            in_contiguous: l1.is_contiguous(),
1013            in_ty: s1.dtype(),
1014            out_ty: s1.dtype(),
1015            src: s1.buffer(),
1016            src_offset: l1.start_offset(), // element offset
1017            dst: &output,
1018            bn,
1019            tn,
1020            n_blocks,
1021        };
1022
1023        // Launch the Metal kernel via the new API
1024        crate::metal_kernels::call_sort(
1025            device.device(), // &metal::Device
1026            &command_buffer, // impl EncoderProvider
1027            &crate::metal_kernels::Kernels::new(),
1028            &sort_args,
1029            &scratch,
1030        )
1031        .map_err(candle_core::Error::wrap)?;
1032
1033        // Wrap and return as a new MetalStorage
1034        let newstorage =
1035            candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1036        Ok((newstorage, out_shape))
1037    }
1038}
1039
1040/// Extension trait adding `argsort` / `sort` convenience calls on `Tensor`.
1041pub trait SortOp {
1042    /// Returns the indices that would (ascending) sort the tensor along `axis`.
1043    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1044    /// Returns the tensor's values (ascending) sorted along `axis`.
1045    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1046}
1047
1048impl SortOp for Tensor {
1049    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1050        if self.device().is_cpu() || self.device().is_cuda() {
1051            return self.arg_sort_last_dim(true);
1052        }
1053        self.apply_op1_no_bwd(&ArgSort {
1054            axis: axis.to_index(self.shape(), "argsort")?,
1055        })
1056    }
1057
1058    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1059        if self.device().is_cpu() || self.device().is_cuda() {
1060            return Ok(self.sort_last_dim(true)?.0);
1061        }
1062        self.apply_op1_no_bwd(&Sort {
1063            axis: axis.to_index(self.shape(), "sort")?,
1064        })
1065    }
1066}
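
// Minimal usage sketch for `SortOp` (hedged example). Note that on CPU and CUDA
// the fallback always sorts along the last dimension via candle's built-in
// `arg_sort_last_dim` / `sort_last_dim`, so `axis` only takes effect on Metal.
#[cfg(test)]
mod sort_example {
    use super::SortOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn sort_and_argsort_last_dim() -> Result<()> {
        let t = Tensor::new(&[3u32, 1, 2], &Device::Cpu)?;
        assert_eq!(t.fast_sort_asc(0)?.to_vec1::<u32>()?, vec![1, 2, 3]);
        // Indices that sort the input ascending: value 1 sits at index 1,
        // value 2 at index 2, value 3 at index 0.
        assert_eq!(t.fast_argsort_asc(0)?.to_vec1::<u32>()?, vec![1, 2, 0]);
        Ok(())
    }
}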
1067
1068struct NonZero;
1069
1070impl NonZero {
1071    // Sequential version
1072    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1073        let n = layout.dims().len();
1074        let mut result = Vec::new();
1075        let mut indices = vec![0u32; n];
1076        for (i, v) in vs.iter().enumerate() {
1077            if !v.is_zero() {
1078                let mut idx = i;
1079                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1080                    let d = idx % dim;
1081                    indices[dim_index] = u32::try_from(d).unwrap();
1082                    idx /= dim;
1083                }
1084                result.extend_from_slice(&indices);
1085            }
1086        }
1087        result
1088    }
1089}
1090
1091#[cfg(feature = "cuda")]
1092fn count_nonzero_cuda(
1093    dtype: candle_core::DType,
1094    d_in: *const c_void,
1095    n: u32,
1096    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1097) -> u32 {
1098    unsafe {
1099        match dtype {
1100            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1101            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1102            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1103            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1104            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1105            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1106            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1107            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1108            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1109            _ => unreachable!(),
1110        }
1111    }
1112}
1113
1114#[allow(clippy::too_many_arguments)]
1115#[cfg(feature = "cuda")]
1116fn nonzero_cuda(
1117    dtype: candle_core::DType,
1118    d_in: *const c_void,
1119    n: u32,
1120    num_nonzero: u32,
1121    dims: *const c_void,
1122    num_dims: u32,
1123    d_out: *mut c_void,
1124    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1125) {
1126    unsafe {
1127        match dtype {
1128            candle_core::DType::U8 => {
1129                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1130            }
1131            candle_core::DType::U32 => {
1132                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1133            }
1134            candle_core::DType::I64 => {
1135                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1136            }
1137            candle_core::DType::I32 => {
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1139            }
1140            candle_core::DType::I16 => {
1141                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1142            }
1143            candle_core::DType::BF16 => {
1144                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1145            }
1146            candle_core::DType::F16 => {
1147                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1148            }
1149            candle_core::DType::F32 => {
1150                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1151            }
1152            candle_core::DType::F64 => {
1153                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1154            }
1155            _ => unreachable!(),
1156        }
1157    }
1158}
1159
1160impl CustomOp1 for NonZero {
1161    fn name(&self) -> &'static str {
1162        "nonzero"
1163    }
1164
1165    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1166        if !layout.is_contiguous() {
1167            return Err(Error::RequiresContiguous { op: "nonzero" });
1168        }
1169        let result = match storage {
1170            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1171            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1172            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1173            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1174            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1175            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1176            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1177            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1178            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1179            _ => unreachable!(),
1180        };
1181        let index_len = layout.dims().len();
1182        let result_len = result.len() / index_len;
1183        let result = CpuStorage::U32(result);
1184        let shape = Shape::from_dims(&[result_len, index_len]);
1185        Ok((result, shape))
1186    }
1187
1188    #[cfg(feature = "cuda")]
1189    fn cuda_fwd(
1190        &self,
1191        storage: &candle_core::CudaStorage,
1192        layout: &Layout,
1193    ) -> Result<(candle_core::CudaStorage, Shape)> {
1194        if !layout.is_contiguous() {
1195            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1196        }
1197        let dev = storage.device().clone();
1198        let (d_in, _d_in_guard) = match storage.dtype() {
1199            candle_core::DType::U8 => {
1200                let slice = storage.as_cuda_slice::<u8>()?;
1201                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1202                (d_in as *const std::ffi::c_void, d_in_guard)
1203            }
1204            candle_core::DType::U32 => {
1205                let slice = storage.as_cuda_slice::<u32>()?;
1206                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1207                (d_in as *const std::ffi::c_void, d_in_guard)
1208            }
1209            candle_core::DType::I32 => {
1210                let slice = storage.as_cuda_slice::<i32>()?;
1211                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1212                (d_in as *const std::ffi::c_void, d_in_guard)
1213            }
1214            candle_core::DType::I16 => {
1215                let slice = storage.as_cuda_slice::<i16>()?;
1216                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1217                (d_in as *const std::ffi::c_void, d_in_guard)
1218            }
1219            candle_core::DType::I64 => {
1220                let slice = storage.as_cuda_slice::<i64>()?;
1221                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1222                (d_in as *const std::ffi::c_void, d_in_guard)
1223            }
1224            candle_core::DType::BF16 => {
1225                let slice = storage.as_cuda_slice::<half::bf16>()?;
1226                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1227                (d_in as *const std::ffi::c_void, d_in_guard)
1228            }
1229            candle_core::DType::F16 => {
1230                let slice = storage.as_cuda_slice::<half::f16>()?;
1231                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1232                (d_in as *const std::ffi::c_void, d_in_guard)
1233            }
1234            candle_core::DType::F32 => {
1235                let slice = storage.as_cuda_slice::<f32>()?;
1236                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1237                (d_in as *const std::ffi::c_void, d_in_guard)
1238            }
1239            candle_core::DType::F64 => {
1240                let slice = storage.as_cuda_slice::<f64>()?;
1241                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1242                (d_in as *const std::ffi::c_void, d_in_guard)
1243            }
1244            _ => unreachable!(),
1245        };
1246        let n = layout.shape().elem_count();
1247
1248        let num_nonzero = count_nonzero_cuda(
1249            storage.dtype(),
1250            d_in,
1251            u32::try_from(n)?,
1252            dev.cuda_stream().cu_stream(),
1253        );
1254        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1255            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1256        if num_nonzero != 0 {
1257            let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1258            let dims = layout
1259                .dims()
1260                .iter()
1261                .map(|&x| u32::try_from(x).unwrap())
1262                .collect::<Vec<u32>>();
1263            let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1264            dev.memcpy_htod(&dims, &mut d_dims)?;
1265            let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1266            nonzero_cuda(
1267                storage.dtype(),
1268                d_in,
1269                u32::try_from(n)?,
1270                num_nonzero,
1271                d_dims_ptr as *const c_void,
1272                u32::try_from(layout.dims().len())?,
1273                d_out as *mut c_void,
1274                dev.cuda_stream().cu_stream(),
1275            );
1276        }
1277        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1278        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1279        Ok((dst, shape))
1280    }
1281}
1282
1283pub trait NonZeroOp {
1284    fn nonzero(&self) -> Result<Tensor>;
1285}
1286
1287impl NonZeroOp for Tensor {
1288    #[cfg(feature = "metal")]
1289    fn nonzero(&self) -> Result<Tensor> {
1290        if !self.is_contiguous() {
1291            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1292        }
1293        let original_device = self.device();
1294        self.to_device(&candle_core::Device::Cpu)?
1295            .apply_op1_no_bwd(&NonZero)?
1296            .to_device(original_device)
1297    }
1298
1299    #[cfg(not(feature = "metal"))]
1300    fn nonzero(&self) -> Result<Tensor> {
1301        if !self.is_contiguous() {
1302            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1303        }
1304        self.apply_op1_no_bwd(&NonZero)
1305    }
1306}
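
// Minimal usage sketch for `NonZeroOp` (hedged example): returns a U32 tensor of
// shape [num_nonzero, num_dims] holding the multi-dimensional index of every
// non-zero element, in row-major order. Assumes a CPU device.
#[cfg(test)]
mod nonzero_example {
    use super::NonZeroOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn nonzero_indices_2d() -> Result<()> {
        let t = Tensor::new(&[[0u32, 1], [2, 0]], &Device::Cpu)?;
        let idx = t.nonzero()?;
        assert_eq!(idx.dims(), &[2, 2]);
        // Non-zero entries are at (0, 1) and (1, 0).
        assert_eq!(idx.to_vec2::<u32>()?, vec![vec![0, 1], vec![1, 0]]);
        Ok(())
    }
}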
1307
1308struct CumSum {
1309    inclusive: bool,
1310    reverse: bool,
1311    axis: usize,
1312}
1313
1314impl CustomOp1 for CumSum {
1315    fn name(&self) -> &'static str {
1316        "cumsum"
1317    }
1318
1319    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1320        use std::ops::Add;
1321        if !l1.is_contiguous() {
1322            candle_core::bail!("Input tensor s1 must be contiguous");
1323        }
1324        let dims = l1.dims();
1325        let axis = self.axis;
1326        let axis_len = dims[axis];
1327        let (start, end) = l1
1328            .contiguous_offsets()
1329            .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1330
1331        // helper to execute scan for a slice of T
1332        macro_rules! scan_block {
1333            ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1334                let vs: &[$ty] = $vt;
1335                let input = &vs[start..end];
1336                let count = input.len() / axis_len;
1337                let mut result = Vec::<$ty>::with_capacity(input.len());
1338                if !self.reverse {
1339                    if self.inclusive {
1340                        for block in 0..count {
1341                            let base = block * axis_len;
1342                            let mut sum = input[base];
1343                            result.push(sum);
1344                            for j in 1..axis_len {
1345                                sum = sum.$add(input[base + j]);
1346                                result.push(sum);
1347                            }
1348                        }
1349                    } else {
1350                        let init: $ty = $init;
1351                        for block in 0..count {
1352                            let base = block * axis_len;
1353                            let mut sum = init;
1354                            for j in 0..axis_len {
1355                                result.push(sum);
1356                                sum = sum.$add(input[base + j]);
1357                            }
1358                        }
1359                    }
1360                } else {
1361                    if self.inclusive {
1362                        for block in 0..count {
1363                            let base = block * axis_len;
1364                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1365                            let mut sum = input[base + axis_len - 1];
1366                            temp.push(sum);
1367                            for k in 1..axis_len {
1368                                let idx = axis_len - 1 - k;
1369                                sum = sum.$add(input[base + idx]);
1370                                temp.push(sum);
1371                            }
1372                            temp.reverse();
1373                            result.extend(temp);
1374                        }
1375                    } else {
1376                        let init: $ty = $init;
1377                        for block in 0..count {
1378                            let base = block * axis_len;
1379                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1380                            let mut sum = init;
1381                            for k in 0..axis_len {
1382                                let idx = axis_len - 1 - k;
1383                                temp.push(sum);
1384                                sum = sum.$add(input[base + idx]);
1385                            }
1386                            temp.reverse();
1387                            result.extend(temp);
1388                        }
1389                    }
1390                }
1391                result
1392            }};
1393        }
1394        match s1 {
1395            CpuStorage::U8(vs) => {
1396                let result = scan_block!(vs, u8, wrapping_add, 0u8);
1397                Ok((CpuStorage::U8(result), l1.shape().clone()))
1398            }
1399            CpuStorage::I16(vs) => {
1400                let result = scan_block!(vs, i16, add, 0i16);
1401                Ok((CpuStorage::I16(result), l1.shape().clone()))
1402            }
1403            CpuStorage::U32(vs) => {
1404                let result = scan_block!(vs, u32, wrapping_add, 0u32);
1405                Ok((CpuStorage::U32(result), l1.shape().clone()))
1406            }
1407            CpuStorage::I32(vs) => {
1408                let result = scan_block!(vs, i32, add, 0i32);
1409                Ok((CpuStorage::I32(result), l1.shape().clone()))
1410            }
1411            CpuStorage::I64(vs) => {
1412                let result = scan_block!(vs, i64, add, 0i64);
1413                Ok((CpuStorage::I64(result), l1.shape().clone()))
1414            }
1415            CpuStorage::F32(vs) => {
1416                let result = scan_block!(vs, f32, add, 0.0f32);
1417                Ok((CpuStorage::F32(result), l1.shape().clone()))
1418            }
1419            CpuStorage::F64(vs) => {
1420                let result = scan_block!(vs, f64, add, 0.0f64);
1421                Ok((CpuStorage::F64(result), l1.shape().clone()))
1422            }
1423            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1424        }
1425    }
1426
1427    #[cfg(feature = "cuda")]
1428    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
1429        candle_core::bail!("cumsum is not yet implemented for CUDA")
1430    }
1431
1432    #[cfg(feature = "metal")]
1433    fn metal_fwd(
1434        &self,
1435        s1: &candle_core::MetalStorage,
1436        l1: &Layout,
1437    ) -> Result<(candle_core::MetalStorage, Shape)> {
1438        use crate::metal_kernels::ScanType;
1439
1440        let command_buffer = s1.device().command_buffer()?;
1441        command_buffer.set_label("cumsum");
1442
1443        let device = s1.device();
1444
1445        let out_shape = l1.shape().clone();
1446
1447        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1448
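        // `l1.start_offset()` is an element offset; the scan kernel is handed a byte
        // offset into `s1.buffer()`, hence the multiplication by the dtype size below.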
1449        crate::metal_kernels::call_scan(
1450            device.device(),
1451            &command_buffer,
1452            &crate::metal_kernels::Kernels::new(),
1453            s1.dtype(),
1454            ScanType::Sum,
1455            s1.buffer(),
1456            l1.start_offset() * s1.dtype().size_in_bytes(),
1457            self.axis,
1458            l1.dims(),
1459            l1.stride(),
1460            self.reverse,
1461            self.inclusive,
1462            &output,
1463        )
1464        .map_err(candle_core::Error::wrap)?;
1465
1466        let newstorage = candle_core::MetalStorage::new(
1467            output,
1468            device.clone(),
1469            out_shape.elem_count(),
1470            s1.dtype(),
1471        );
1472        Ok((newstorage, out_shape))
1473    }
1474}
1475
1476#[allow(dead_code)]
1477pub trait CumSumOp {
1478    /// Exclusive forward cumulative sum along `axis` (equivalent to `fast_cumsum_config(axis, false, false)`).
1479    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1480
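    /// Cumulative sum along `axis`. `inclusive` controls whether each element is
    /// included in its own running total; `reverse` scans from the end of the axis
    /// towards the start.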
1481    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1482        -> Result<Tensor>;
1483}
1484
1485impl CumSumOp for Tensor {
1486    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1487        self.fast_cumsum_config(axis, false, false)
1488    }
1489
1490    fn fast_cumsum_config<D: Dim>(
1491        &self,
1492        axis: D,
1493        inclusive: bool,
1494        reverse: bool,
1495    ) -> Result<Tensor> {
1496        self.apply_op1_no_bwd(&CumSum {
1497            inclusive,
1498            reverse,
1499            axis: axis.to_index(self.shape(), "cumsum")?,
1500        })
1501    }
1502}
1503
1504#[cfg(test)] mod tests {
1505    #[test]
1506    fn test_cumsum_exclusive_forward_cpu() {
1507        use crate::utils::ops::CumSumOp;
1508        use candle_core::Tensor;
1509        let device = candle_core::Device::Cpu;
1510        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1511        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1512        assert_eq!(b, [0, 1, 3, 6]);
1513    }
1514
1515    #[test]
1516    fn test_cumsum_inclusive_forward_cpu() {
1517        use crate::utils::ops::CumSumOp;
1518        use candle_core::Tensor;
1519        let device = candle_core::Device::Cpu;
1520        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1521        let b = a
1522            .fast_cumsum_config(0, true, false)
1523            .unwrap()
1524            .to_vec1::<i64>()
1525            .unwrap();
1526        assert_eq!(b, [1, 3, 6, 10]);
1527    }
1528
1529    #[test]
1530    fn test_cumsum_exclusive_reverse_cpu() {
1531        use crate::utils::ops::CumSumOp;
1532        use candle_core::Tensor;
1533        let device = candle_core::Device::Cpu;
1534        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1535        let b = a
1536            .fast_cumsum_config(0, false, true)
1537            .unwrap()
1538            .to_vec1::<i64>()
1539            .unwrap();
1540        assert_eq!(b, [9, 7, 4, 0]);
1541    }
1542
1543    #[test]
1544    fn test_cumsum_inclusive_reverse_cpu() {
1545        use crate::utils::ops::CumSumOp;
1546        use candle_core::Tensor;
1547        let device = candle_core::Device::Cpu;
1548        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1549        let b = a
1550            .fast_cumsum_config(0, true, true)
1551            .unwrap()
1552            .to_vec1::<i64>()
1553            .unwrap();
1554        assert_eq!(b, [10, 9, 7, 4]);
1555    }
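
    // Illustrative extra test: an exclusive scan along the last (contiguous) axis of a
    // 2-D tensor, the layout the block-wise CPU path above is written for.
    #[test]
    fn test_cumsum_exclusive_forward_2d_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6], &[2, 3], &device).unwrap();
        let b = a.fast_cumsum(1).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(b, [[0, 1, 3], [0, 4, 9]]);
    }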
1556
1557    #[cfg(feature = "metal")]
1558    #[test]
1559    fn test_cumsum_exclusive_forward_metal() {
1560        use crate::utils::ops::CumSumOp;
1561        use candle_core::Tensor;
1562        let device = candle_core::Device::new_metal(0).unwrap();
1563        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1564        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
1565        assert_eq!(b, [0, 1, 3, 6]);
1566    }
1567
1568    #[cfg(feature = "metal")]
1569    #[test]
1570    fn test_cumsum_inclusive_forward_metal() {
1571        use crate::utils::ops::CumSumOp;
1572        use candle_core::Tensor;
1573        let device = candle_core::Device::new_metal(0).unwrap();
1574        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1575        let b = a
1576            .fast_cumsum_config(0, true, false)
1577            .unwrap()
1578            .to_vec1::<i64>()
1579            .unwrap();
1580        assert_eq!(b, [1, 3, 6, 10]);
1581    }
1582
1583    #[cfg(feature = "metal")]
1584    #[test]
1585    fn test_cumsum_exclusive_reverse_metal() {
1586        use crate::utils::ops::CumSumOp;
1587        use candle_core::Tensor;
1588        let device = candle_core::Device::new_metal(0).unwrap();
1589        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1590        let b = a
1591            .fast_cumsum_config(0, false, true)
1592            .unwrap()
1593            .to_vec1::<i64>()
1594            .unwrap();
1595        assert_eq!(b, [9, 7, 4, 0]);
1596    }
1597
1598    #[cfg(feature = "metal")]
1599    #[test]
1600    fn test_cumsum_inclusive_reverse_metal() {
1601        use crate::utils::ops::CumSumOp;
1602        use candle_core::Tensor;
1603        let device = candle_core::Device::new_metal(0).unwrap();
1604        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
1605        let b = a
1606            .fast_cumsum_config(0, true, true)
1607            .unwrap()
1608            .to_vec1::<i64>()
1609            .unwrap();
1610        assert_eq!(b, [10, 9, 7, 4]);
1611    }
1612
1613    #[test]
1614    fn test_nonzero_cpu() {
1615        use crate::utils::ops::NonZeroOp;
1616        use candle_core::Tensor;
1617        let device = candle_core::Device::Cpu;
1618        let a = Tensor::from_vec(
1619            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1620            &[2, 4],
1621            &device,
1622        )
1623        .unwrap();
1624        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1625        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1626    }
1627
1628    #[cfg(feature = "cuda")]
1629    #[test]
1630    fn test_nonzero_cuda() {
1631        use crate::utils::ops::NonZeroOp;
1632        use candle_core::Tensor;
1633        let device = candle_core::Device::new_cuda(0).unwrap();
1634        let a = Tensor::from_vec(
1635            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
1636            &[2, 4],
1637            &device,
1638        )
1639        .unwrap();
1640        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
1641        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
1642    }
1643
1644    #[test]
1645    fn test_bitwise_and_cpu() {
1646        use crate::utils::ops::BitWiseOp;
1647        use candle_core::Tensor;
1648        let device = candle_core::Device::Cpu;
1649        let a =
1650            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1651        let b =
1652            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1653        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1654        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
1655    }
1656
1657    #[cfg(feature = "cuda")]
1658    #[test]
1659    fn test_bitwise_and_cuda() {
1660        use crate::utils::ops::BitWiseOp;
1661        use candle_core::Tensor;
1662        let device = candle_core::Device::new_cuda(0).unwrap();
1663        let a =
1664            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1665        let b =
1666            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
1667        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
1668        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
1669    }
1670
1671    #[test]
1672    fn test_bitwise_or_cpu() {
1673        use crate::utils::ops::BitWiseOp;
1674        use candle_core::Tensor;
1675        let device = candle_core::Device::Cpu;
1676        let a =
1677            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1678        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1679        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1680        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1681    }
1682
1683    #[cfg(feature = "cuda")]
1684    #[test]
1685    fn test_bitwise_or_cuda() {
1686        use crate::utils::ops::BitWiseOp;
1687        use candle_core::Tensor;
1688        let device = candle_core::Device::new_cuda(0).unwrap();
1689        let a =
1690            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1691        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1692        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
1693        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1694    }
1695
1696    #[test]
1697    fn test_bitwise_xor_cpu() {
1698        use crate::utils::ops::BitWiseOp;
1699        use candle_core::Tensor;
1700        let device = candle_core::Device::Cpu;
1701        let a =
1702            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1703        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1704        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1705        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1706    }
1707
1708    #[cfg(feature = "cuda")]
1709    #[test]
1710    fn test_bitwise_xor_cuda() {
1711        use crate::utils::ops::BitWiseOp;
1712        use candle_core::Tensor;
1713        let device = candle_core::Device::new_cuda(0).unwrap();
1714        let a =
1715            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
1716        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
1717        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
1718        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
1719    }
1720
1721    #[test]
1722    fn test_nonzero_and() {
1723        use crate::utils::ops::{BitWiseOp, NonZeroOp};
1724        use candle_core::{Device, Tensor};
1725
1726        let input1 = Tensor::from_vec(
1727            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
1728            (10,),
1729            &Device::Cpu,
1730        )
1731        .unwrap();
1732        let input2 = Tensor::from_vec(
1733            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
1734            (10,),
1735            &Device::Cpu,
1736        )
1737        .unwrap();
1738        let input = Tensor::stack(&[input1, input2], 0).unwrap();
1739
1740        let lt = input.lt(0.0).unwrap();
1741        let gt = input.gt(-10.0).unwrap();
1742        let res = lt
1743            .bitwise_and(&gt)
1744            .unwrap()
1745            .nonzero()
1746            .unwrap()
1747            .to_vec2::<u32>()
1748            .unwrap();
1749
1750        assert_eq!(
1751            res,
1752            [
1753                [0, 3],
1754                [0, 4],
1755                [0, 5],
1756                [0, 6],
1757                [1, 0],
1758                [1, 3],
1759                [1, 5],
1760                [1, 6]
1761            ]
1762        );
1763    }
1764
1765    #[cfg(feature = "cuda")]
1766    #[test]
1767    fn nonzero_and_cuda() {
1768        use crate::utils::ops::{BitWiseOp, NonZeroOp};
1769        use candle_core::{Device, Tensor};
1770
1771        let device = Device::new_cuda(0).unwrap();
1772        let input1 =
1773            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1774        let input2 =
1775            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
1776        let input = Tensor::stack(&[input1, input2], 0).unwrap();
1777
1778        let lt = input.lt(0.0).unwrap();
1779        let gt = input.gt(-10.0).unwrap();
1780        let res = lt
1781            .bitwise_and(&gt)
1782            .unwrap()
1783            .nonzero()
1784            .unwrap()
1785            .to_vec2::<u32>()
1786            .unwrap();
1787
1788        assert_eq!(
1789            res,
1790            [
1791                [0, 3],
1792                [0, 4],
1793                [0, 5],
1794                [0, 6],
1795                [1, 0],
1796                [1, 3],
1797                [1, 5],
1798                [1, 6]
1799            ]
1800        );
1801    }
1802
1803    #[test]
1804    fn test_bitpack_8bit_cpu() {
1805        use crate::HqqBits;
1806        use candle_core::{Device, Tensor};
1807        let bits = HqqBits::Eight;
1808        let device = Device::Cpu;
1809        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1810        let c = bits.bitpack_type()(wq.clone())
1811            .unwrap()
1812            .to_vec2::<u8>()
1813            .unwrap();
1814        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1815    }
1816
1817    #[cfg(feature = "cuda")]
1818    #[test]
1819    fn test_bitpack_8bit_cuda() {
1820        use crate::HqqBits;
1821        use candle_core::DType;
1822        use candle_core::{Device, Tensor};
1823        let bits = HqqBits::Eight;
1824        let device = Device::new_cuda(0).unwrap();
1825        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1826        let c = bits.bitpack_type()(wq.clone())
1827            .unwrap()
1828            .to_dtype(DType::U8)
1829            .unwrap()
1830            .to_vec2::<u8>()
1831            .unwrap();
1832        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1833    }
1834
1835    #[cfg(feature = "metal")]
1836    #[test]
1837    fn test_bitpack_8bit_metal() {
1838        use crate::HqqBits;
1839        use candle_core::{Device, Tensor};
1840        let bits = HqqBits::Eight;
1841        let device = Device::new_metal(0).unwrap();
1842        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
1843        let c = bits.bitpack_type()(wq.clone())
1844            .unwrap()
1845            .to_vec2::<u8>()
1846            .unwrap();
1847        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
1848    }
1849
1850    #[test]
1851    fn test_bitpack_4bit() {
1852        use crate::HqqBits;
1853        use candle_core::{Device, Tensor};
1854        let bits = HqqBits::Four;
1855        let device = Device::Cpu;
1856        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1857        let c = bits.bitpack_type()(wq.clone())
1858            .unwrap()
1859            .to_vec2::<u8>()
1860            .unwrap();
1861        assert_eq!(c, [[19, 36]]);
1862    }
1863
1864    #[cfg(feature = "cuda")]
1865    #[test]
1866    fn test_bitpack_4bit_cuda() {
1867        use crate::HqqBits;
1868        use candle_core::{Device, Tensor};
1869        let bits = HqqBits::Four;
1870        let device = Device::new_cuda(0).unwrap();
1871        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1872        let c = bits.bitpack_type()(wq.clone())
1873            .unwrap()
1874            .to_vec2::<u8>()
1875            .unwrap();
1876        assert_eq!(c, [[19, 36]]);
1877    }
1878
1879    #[cfg(feature = "metal")]
1880    #[test]
1881    fn test_bitpack_4bit_metal() {
1882        use crate::HqqBits;
1883        use candle_core::{Device, Tensor};
1884        let bits = HqqBits::Four;
1885        let device = Device::new_metal(0).unwrap();
1886        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
1887        let c = bits.bitpack_type()(wq.clone())
1888            .unwrap()
1889            .to_vec2::<u8>()
1890            .unwrap();
1891        assert_eq!(c, [[19, 36]]);
1892    }
1893    // ─────────────────────────────── Sort / ArgSort ────────────────────────────────
1894    #[cfg(feature = "metal")]
1895    #[test]
1896    fn test_sort_and_argsort_vector_metal() {
1897        use crate::utils::ops::SortOp;
1898        use candle_core::Tensor;
1899
1900        let device = candle_core::Device::new_metal(0).unwrap();
1901        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
1902
1903        // sort (ascending)
1904        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1905        assert_eq!(sorted, [1, 2, 3, 4]);
1906
1907        // argsort (ascending indices)
1908        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1909        assert_eq!(idx, [1, 3, 0, 2]);
1910    }
1911
1912    #[cfg(feature = "metal")]
1913    #[test]
1914    fn test_sort_and_argsort_matrix_axis1_metal() {
1915        use crate::utils::ops::SortOp;
1916        use candle_core::Tensor;
1917
1918        let device = candle_core::Device::new_metal(0).unwrap();
1919        // 2 × 3 matrix:
1920        // [[3, 1, 2],
1921        //  [0, 4, 5]]
1922        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
1923
1924        // Sort along axis=1 (second dimension)
1925        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
1926        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
1927
1928        // ArgSort indices along axis=1
1929        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
1930        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
1931    }
1932
1933    // ─────────────────────────────── 4 096-element vector ────────────────────────────────
1934    #[cfg(feature = "metal")]
1935    #[test]
1936    fn test_sort_and_argsort_vector_4096_metal() {
1937        use crate::utils::ops::SortOp;
1938        use candle_core::Tensor;
1939
1940        const N: usize = 4096;
1941
1942        let device = candle_core::Device::new_metal(0).expect("Metal device");
1943
1944        // Create a descending vector [4095, 4094, …, 0]
1945        let vals: Vec<i32> = (0..N as i32).rev().collect();
1946        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
1947
1948        // ---- sort (ascending) ---------------------------------------------------------
1949        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
1950        let expected: Vec<i32> = (0..N as i32).collect();
1951        assert_eq!(sorted, expected);
1952
1953        // ---- argsort (indices that would sort) ---------------------------------------
1954        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
1955        // Because the input is reversed, the correct indices are likewise reversed
1956        for (i, &v) in idx.iter().enumerate() {
1957            assert_eq!(v as usize, N - 1 - i);
1958        }
1959    }
1960}