mistralrs_quant/utils/ops.rs

1use candle_core::{
2    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3    Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::{
8    fmt::Display,
9    ops::{BitAnd, BitOr, BitXor, Not, Shl},
10};
11
12#[cfg(feature = "cuda")]
13use crate::utils::{ffi, slice_ptr};
14#[cfg(feature = "cuda")]
15use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage};
16#[cfg(feature = "cuda")]
17use std::ffi::c_void;
18
19#[cfg(feature = "metal")]
20use crate::metal_kernels::SortScratchCache;
21#[cfg(feature = "metal")]
22use std::sync::OnceLock;
23
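// Process-wide scratch-buffer cache, lazily initialized on first use and shared by
// the argsort / sort Metal kernels below.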
24#[cfg(feature = "metal")]
25static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
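// Element-wise left shift of an integer tensor; the wrapped value is the shift amount in bits.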
27struct Leftshift(usize);
28
29impl Leftshift {
30    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31        let offset = T::from_f64(self.0 as f64);
32        vs.into_par_iter().map(|v| *v << offset).collect()
33    }
34}
35
36impl CustomOp1 for Leftshift {
37    fn name(&self) -> &'static str {
38        "left"
39    }
40
41    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42        match s1 {
43            CpuStorage::U8(vs1) => {
44                let vs1 = match l1.contiguous_offsets() {
45                    Some((a, b)) => &vs1[a..b],
46                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
47                };
48                let result = self.leftshift(vs1);
49                let result = CpuStorage::U8(result);
50                Ok((result, l1.shape().clone()))
51            }
52            CpuStorage::I16(vs1) => {
53                let vs1 = match l1.contiguous_offsets() {
54                    Some((a, b)) => &vs1[a..b],
55                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
56                };
57                let result = self.leftshift(vs1);
58                let result = CpuStorage::I16(result);
59                Ok((result, l1.shape().clone()))
60            }
61            CpuStorage::U32(vs1) => {
62                let vs1 = match l1.contiguous_offsets() {
63                    Some((a, b)) => &vs1[a..b],
64                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
65                };
66                let result = self.leftshift(vs1);
67                let result = CpuStorage::U32(result);
68                Ok((result, l1.shape().clone()))
69            }
70            CpuStorage::I64(vs1) => {
71                let vs1 = match l1.contiguous_offsets() {
72                    Some((a, b)) => &vs1[a..b],
73                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
74                };
75                let result = self.leftshift(vs1);
76                let result = CpuStorage::I64(result);
77                Ok((result, l1.shape().clone()))
78            }
79            CpuStorage::I32(vs1) => {
80                let vs1 = match l1.contiguous_offsets() {
81                    Some((a, b)) => &vs1[a..b],
82                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
83                };
84                let result = self.leftshift(vs1);
85                let result = CpuStorage::I32(result);
86                Ok((result, l1.shape().clone()))
87            }
88            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
89        }
90    }
91
92    #[cfg(feature = "cuda")]
93    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
94        if !l1.is_contiguous() {
95            candle_core::bail!("Input tensor s1 must be contiguous");
96        }
97        let dev = s1.device().clone();
98        let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
99            DType::U8 => {
100                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
101                let elem_count = l1.shape().elem_count();
102                (d_in1 as *const c_void, d_in1_guard, elem_count)
103            }
104            DType::I32 => {
105                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
106                let elem_count = l1.shape().elem_count();
107                (d_in1 as *const c_void, d_in1_guard, elem_count)
108            }
109            other => {
110                return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
111            }
112        };
113        let dst = match s1.dtype() {
114            DType::U8 => {
115                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
116                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
117                unsafe {
118                    ffi::leftshift_u8(
119                        d_in1_ptr,
120                        d_out_ptr as *mut std::ffi::c_void,
121                        u32::try_from(elem_count)?,
122                        self.0 as i32,
123                    )
124                };
125                drop(d_out_guard);
126                CudaStorage::wrap_cuda_slice(d_out, dev)
127            }
128            DType::I32 => {
129                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
130                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
131                unsafe {
132                    ffi::leftshift_i32(
133                        d_in1_ptr,
134                        d_out_ptr as *mut std::ffi::c_void,
135                        u32::try_from(elem_count)?,
136                        self.0 as i32,
137                    )
138                };
139                drop(d_out_guard);
140                CudaStorage::wrap_cuda_slice(d_out, dev)
141            }
142            _ => unreachable!(),
143        };
144        Ok((dst, l1.shape().clone()))
145    }
146
147    #[cfg(feature = "metal")]
148    fn metal_fwd(
149        &self,
150        s1: &candle_core::MetalStorage,
151        l1: &Layout,
152    ) -> Result<(candle_core::MetalStorage, Shape)> {
153        if !l1.is_contiguous() {
154            candle_core::bail!("Input tensor s1 must be contiguous");
155        }
156
157        let encoder = s1.device().command_encoder()?;
158        encoder.set_label("bitwise-leftshift");
159
160        let device = s1.device();
161
162        let out_shape = l1.shape().clone();
163
164        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
165
166        crate::metal_kernels::call_bitwise_leftshift(
167            device.device(),
168            &encoder,
169            &crate::metal_kernels::Kernels::new(),
170            s1.dtype(),
171            s1.buffer(),
172            l1.start_offset(),
173            self.0 as u32,
174            out_shape.elem_count(),
175            &output,
176        )
177        .map_err(candle_core::Error::wrap)?;
178
179        let newstorage = candle_core::MetalStorage::new(
180            output,
181            device.clone(),
182            out_shape.elem_count(),
183            s1.dtype(),
184        );
185        Ok((newstorage, out_shape))
186    }
187}
188
189#[allow(dead_code)]
190pub trait LeftshiftOp {
191    fn leftshift(&self, n: usize) -> Result<Tensor>;
192}
193
194impl LeftshiftOp for Tensor {
195    fn leftshift(&self, n: usize) -> Result<Tensor> {
196        self.apply_op1_no_bwd(&Leftshift(n))
197    }
198}
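// A minimal usage sketch of `LeftshiftOp`, assuming a CPU device: every element of a
// contiguous integer tensor is shifted left by a constant number of bits.
#[cfg(test)]
mod leftshift_example {
    use super::LeftshiftOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn leftshift_u8() -> Result<()> {
        let t = Tensor::new(&[1u8, 2, 4], &Device::Cpu)?;
        // Shift every element left by two bits: 1 -> 4, 2 -> 8, 4 -> 16.
        assert_eq!(t.leftshift(2)?.to_vec1::<u8>()?, vec![4u8, 8, 16]);
        Ok(())
    }
}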
199
200pub enum BitWiseBinaryOpEnum {
201    And,
202    Or,
203    Xor,
204}
205
206impl Display for BitWiseBinaryOpEnum {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        match self {
209            BitWiseBinaryOpEnum::And => write!(f, "And"),
210            BitWiseBinaryOpEnum::Or => write!(f, "Or"),
211            BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
212        }
213    }
214}
215
216pub enum BitWiseUnaryOpEnum {
217    Not,
218}
219
220impl Display for BitWiseUnaryOpEnum {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            BitWiseUnaryOpEnum::Not => write!(f, "Not"),
224        }
225    }
226}
227
228struct BitWise {
229    pub op: BitWiseBinaryOpEnum,
230}
231
232impl BitWise {
233    pub fn new(op: BitWiseBinaryOpEnum) -> Self {
234        Self { op }
235    }
236
237    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
238        &self,
239        vs1: &[T],
240        vs2: &[T],
241    ) -> Vec<T> {
242        vs1.into_par_iter()
243            .zip_eq(vs2)
244            .map(|(v1, v2)| match self.op {
245                BitWiseBinaryOpEnum::And => *v1 & *v2,
246                BitWiseBinaryOpEnum::Or => *v1 | *v2,
247                BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
248            })
249            .collect()
250    }
251}
252
253impl CustomOp2 for BitWise {
254    fn name(&self) -> &'static str {
255        "bitwise"
256    }
257
258    fn cpu_fwd(
259        &self,
260        s1: &CpuStorage,
261        l1: &Layout,
262        s2: &CpuStorage,
263        l2: &Layout,
264    ) -> Result<(CpuStorage, Shape)> {
265        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
266            return Err(Error::ShapeMismatchBinaryOp {
267                lhs: l1.shape().clone(),
268                rhs: l2.shape().clone(),
269                op: "bitwise-op",
270            });
271        }
272        if s1.dtype() != s2.dtype() {
273            return Err(Error::DTypeMismatchBinaryOp {
274                lhs: s1.dtype(),
275                rhs: s2.dtype(),
276                op: "bitwise-op",
277            });
278        }
279        if !l1.is_contiguous() {
280            candle_core::bail!("Input tensor s1 must be contiguous");
281        }
282        if !l2.is_contiguous() {
283            candle_core::bail!("Input tensor s2 must be contiguous");
284        }
285
286        match s1 {
287            CpuStorage::U8(vs1) => {
288                let vs2 = s2.as_slice::<u8>().unwrap();
289                let vs1 = match l1.contiguous_offsets() {
290                    Some((a, b)) => &vs1[a..b],
291                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
292                };
293                let vs2 = match l2.contiguous_offsets() {
294                    Some((a, b)) => &vs2[a..b],
295                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
296                };
297                let result = self.bitwise(vs1, vs2);
298                let result = CpuStorage::U8(result);
299                Ok((result, l1.shape().clone()))
300            }
301            CpuStorage::U32(vs1) => {
302                let vs2 = s2.as_slice::<u32>().unwrap();
303                let vs1 = match l1.contiguous_offsets() {
304                    Some((a, b)) => &vs1[a..b],
305                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
306                };
307                let vs2 = match l2.contiguous_offsets() {
308                    Some((a, b)) => &vs2[a..b],
309                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
310                };
311                let result = self.bitwise(vs1, vs2);
312                let result = CpuStorage::U32(result);
313                Ok((result, l1.shape().clone()))
314            }
315            CpuStorage::I64(vs1) => {
316                let vs2 = s2.as_slice::<i64>().unwrap();
317                let vs1 = match l1.contiguous_offsets() {
318                    Some((a, b)) => &vs1[a..b],
319                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
320                };
321                let vs2 = match l2.contiguous_offsets() {
322                    Some((a, b)) => &vs2[a..b],
323                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
324                };
325                let result = self.bitwise(vs1, vs2);
326                let result = CpuStorage::I64(result);
327                Ok((result, l1.shape().clone()))
328            }
329            CpuStorage::I16(vs1) => {
330                let vs2 = s2.as_slice::<i16>().unwrap();
331                let vs1 = match l1.contiguous_offsets() {
332                    Some((a, b)) => &vs1[a..b],
333                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
334                };
335                let vs2 = match l2.contiguous_offsets() {
336                    Some((a, b)) => &vs2[a..b],
337                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
338                };
339                let result = self.bitwise(vs1, vs2);
340                let result = CpuStorage::I16(result);
341                Ok((result, l1.shape().clone()))
342            }
343            CpuStorage::I32(vs1) => {
344                let vs2 = s2.as_slice::<i32>().unwrap();
345                let vs1 = match l1.contiguous_offsets() {
346                    Some((a, b)) => &vs1[a..b],
347                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
348                };
349                let vs2 = match l2.contiguous_offsets() {
350                    Some((a, b)) => &vs2[a..b],
351                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
352                };
353                let result = self.bitwise(vs1, vs2);
354                let result = CpuStorage::I32(result);
355                Ok((result, l1.shape().clone()))
356            }
357            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
358        }
359    }
360
361    #[cfg(feature = "cuda")]
362    fn cuda_fwd(
363        &self,
364        s1: &CudaStorage,
365        l1: &Layout,
366        s2: &CudaStorage,
367        l2: &Layout,
368    ) -> Result<(CudaStorage, Shape)> {
369        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
370            return Err(Error::ShapeMismatchBinaryOp {
371                lhs: l1.shape().clone(),
372                rhs: l2.shape().clone(),
373                op: "bitwise-op",
374            });
375        }
376        if s1.dtype() != s2.dtype() {
377            return Err(Error::DTypeMismatchBinaryOp {
378                lhs: s1.dtype(),
379                rhs: s2.dtype(),
380                op: "bitwise-op",
381            });
382        }
383        if !l1.is_contiguous() {
384            candle_core::bail!("Input tensor s1 must be contiguous");
385        }
386        if !l2.is_contiguous() {
387            candle_core::bail!("Input tensor s2 must be contiguous");
388        }
389
390        let dev = s1.device().clone();
391        let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
392            DType::U8 => {
393                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
394                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
395                let elem_count = l1.shape().elem_count();
396                (
397                    d_in1 as *const std::ffi::c_void,
398                    d_in2 as *const std::ffi::c_void,
399                    d_in1_guard,
400                    d_in2_guard,
401                    elem_count,
402                )
403            }
404            DType::U32 => {
405                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
406                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
407                let elem_count = l1.shape().elem_count();
408                (
409                    d_in1 as *const std::ffi::c_void,
410                    d_in2 as *const std::ffi::c_void,
411                    d_in1_guard,
412                    d_in2_guard,
413                    elem_count,
414                )
415            }
416            DType::I64 => {
417                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
418                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
419                let elem_count = l1.shape().elem_count();
420                (
421                    d_in1 as *const std::ffi::c_void,
422                    d_in2 as *const std::ffi::c_void,
423                    d_in1_guard,
424                    d_in2_guard,
425                    elem_count,
426                )
427            }
428            DType::I32 => {
429                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
430                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
431                let elem_count = l1.shape().elem_count();
432                (
433                    d_in1 as *const std::ffi::c_void,
434                    d_in2 as *const std::ffi::c_void,
435                    d_in1_guard,
436                    d_in2_guard,
437                    elem_count,
438                )
439            }
440            DType::I16 => {
441                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
442                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
443                let elem_count = l1.shape().elem_count();
444                (
445                    d_in1 as *const std::ffi::c_void,
446                    d_in2 as *const std::ffi::c_void,
447                    d_in1_guard,
448                    d_in2_guard,
449                    elem_count,
450                )
451            }
452            other => {
453                return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
454            }
455        };
456        let dst = match s1.dtype() {
457            DType::U8 => {
458                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
459                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
460                unsafe {
461                    match self.op {
462                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
463                            d_in1_ptr,
464                            d_in2_ptr,
465                            d_out_ptr as *mut c_void,
466                            u32::try_from(elem_count)?,
467                        ),
468                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
469                            d_in1_ptr,
470                            d_in2_ptr,
471                            d_out_ptr as *mut c_void,
472                            u32::try_from(elem_count)?,
473                        ),
474                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
475                            d_in1_ptr,
476                            d_in2_ptr,
477                            d_out_ptr as *mut c_void,
478                            u32::try_from(elem_count)?,
479                        ),
480                    }
481                };
482                drop(d_out_guard);
483                CudaStorage::wrap_cuda_slice(d_out, dev)
484            }
485            DType::U32 => {
486                let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
487                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
488                unsafe {
489                    match self.op {
490                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
491                            d_in1_ptr,
492                            d_in2_ptr,
493                            d_out_ptr as *mut c_void,
494                            u32::try_from(elem_count)?,
495                        ),
496                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
497                            d_in1_ptr,
498                            d_in2_ptr,
499                            d_out_ptr as *mut c_void,
500                            u32::try_from(elem_count)?,
501                        ),
502                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
503                            d_in1_ptr,
504                            d_in2_ptr,
505                            d_out_ptr as *mut c_void,
506                            u32::try_from(elem_count)?,
507                        ),
508                    }
509                };
510                drop(d_out_guard);
511                CudaStorage::wrap_cuda_slice(d_out, dev)
512            }
513            DType::I64 => {
514                let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
515                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
516                unsafe {
517                    match self.op {
518                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
519                            d_in1_ptr,
520                            d_in2_ptr,
521                            d_out_ptr as *mut c_void,
522                            u32::try_from(elem_count)?,
523                        ),
524                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
525                            d_in1_ptr,
526                            d_in2_ptr,
527                            d_out_ptr as *mut c_void,
528                            u32::try_from(elem_count)?,
529                        ),
530                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
531                            d_in1_ptr,
532                            d_in2_ptr,
533                            d_out_ptr as *mut c_void,
534                            u32::try_from(elem_count)?,
535                        ),
536                    }
537                };
538                drop(d_out_guard);
539                CudaStorage::wrap_cuda_slice(d_out, dev)
540            }
541            DType::I32 => {
542                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
543                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
544                unsafe {
545                    match self.op {
546                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
547                            d_in1_ptr,
548                            d_in2_ptr,
549                            d_out_ptr as *mut c_void,
550                            u32::try_from(elem_count)?,
551                        ),
552                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
553                            d_in1_ptr,
554                            d_in2_ptr,
555                            d_out_ptr as *mut c_void,
556                            u32::try_from(elem_count)?,
557                        ),
558                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
559                            d_in1_ptr,
560                            d_in2_ptr,
561                            d_out_ptr as *mut c_void,
562                            u32::try_from(elem_count)?,
563                        ),
564                    }
565                };
566                drop(d_out_guard);
567                CudaStorage::wrap_cuda_slice(d_out, dev)
568            }
569            _ => unreachable!(),
570        };
571        Ok((dst, l1.shape().clone()))
572    }
573
574    #[cfg(feature = "metal")]
575    fn metal_fwd(
576        &self,
577        s1: &candle_core::MetalStorage,
578        l1: &Layout,
579        s2: &candle_core::MetalStorage,
580        l2: &Layout,
581    ) -> Result<(candle_core::MetalStorage, Shape)> {
582        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
583            return Err(Error::ShapeMismatchBinaryOp {
584                lhs: l1.shape().clone(),
585                rhs: l2.shape().clone(),
586                op: "bitwise-op",
587            });
588        }
589        if s1.dtype() != s2.dtype() {
590            return Err(Error::DTypeMismatchBinaryOp {
591                lhs: s1.dtype(),
592                rhs: s2.dtype(),
593                op: "bitwise-op",
594            });
595        }
596        if !l1.is_contiguous() {
597            candle_core::bail!("Input tensor s1 must be contiguous");
598        }
599        if !l2.is_contiguous() {
600            candle_core::bail!("Input tensor s2 must be contiguous");
601        }
602
603        let encoder = s1.device().command_encoder()?;
604        encoder.set_label("bitwise-op");
605
606        let device = s1.device();
607
608        let out_shape = l1.shape().clone();
609
610        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
611
612        match self.op {
613            BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
614                device.device(),
615                &encoder,
616                &crate::metal_kernels::Kernels::new(),
617                s1.dtype(),
618                s1.buffer(),
619                s2.buffer(),
620                l1.start_offset() * s1.dtype().size_in_bytes(),
621                l2.start_offset() * s2.dtype().size_in_bytes(),
622                out_shape.elem_count(),
623                &output,
624            )
625            .map_err(candle_core::Error::wrap)?,
626            BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
627                device.device(),
628                &encoder,
629                &crate::metal_kernels::Kernels::new(),
630                s1.dtype(),
631                s1.buffer(),
632                s2.buffer(),
633                l1.start_offset() * s1.dtype().size_in_bytes(),
634                l2.start_offset() * s2.dtype().size_in_bytes(),
635                out_shape.elem_count(),
636                &output,
637            )
638            .map_err(candle_core::Error::wrap)?,
639            BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
640                device.device(),
641                &encoder,
642                &crate::metal_kernels::Kernels::new(),
643                s1.dtype(),
644                s1.buffer(),
645                s2.buffer(),
646                l1.start_offset() * s1.dtype().size_in_bytes(),
647                l2.start_offset() * s2.dtype().size_in_bytes(),
648                out_shape.elem_count(),
649                &output,
650            )
651            .map_err(candle_core::Error::wrap)?,
652        }
653
654        let newstorage = candle_core::MetalStorage::new(
655            output,
656            device.clone(),
657            out_shape.elem_count(),
658            s1.dtype(),
659        );
660        Ok((newstorage, out_shape))
661    }
662}
663
664struct BitWiseUnary {
665    pub op: BitWiseUnaryOpEnum,
666}
667
668impl BitWiseUnary {
669    pub fn new(op: BitWiseUnaryOpEnum) -> Self {
670        Self { op }
671    }
672
673    fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
674        vs1.into_par_iter()
675            .map(|v1| match self.op {
676                BitWiseUnaryOpEnum::Not => !*v1,
677            })
678            .collect()
679    }
680}
681
682impl CustomOp1 for BitWiseUnary {
683    fn name(&self) -> &'static str {
684        "bitwise-unary"
685    }
686
687    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
688        if !l1.is_contiguous() {
689            candle_core::bail!("Input tensor s1 must be contiguous");
690        }
691
692        match s1 {
693            CpuStorage::U8(vs1) => {
694                let vs1 = match l1.contiguous_offsets() {
695                    Some((a, b)) => &vs1[a..b],
696                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
697                };
698                let result = self.bitwise(vs1);
699                let result = CpuStorage::U8(result);
700                Ok((result, l1.shape().clone()))
701            }
702            CpuStorage::U32(vs1) => {
703                let vs1 = match l1.contiguous_offsets() {
704                    Some((a, b)) => &vs1[a..b],
705                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
706                };
707                let result = self.bitwise(vs1);
708                let result = CpuStorage::U32(result);
709                Ok((result, l1.shape().clone()))
710            }
711            CpuStorage::I64(vs1) => {
712                let vs1 = match l1.contiguous_offsets() {
713                    Some((a, b)) => &vs1[a..b],
714                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
715                };
716                let result = self.bitwise(vs1);
717                let result = CpuStorage::I64(result);
718                Ok((result, l1.shape().clone()))
719            }
720            CpuStorage::I16(vs1) => {
721                let vs1 = match l1.contiguous_offsets() {
722                    Some((a, b)) => &vs1[a..b],
723                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
724                };
725                let result = self.bitwise(vs1);
726                let result = CpuStorage::I16(result);
727                Ok((result, l1.shape().clone()))
728            }
729            CpuStorage::I32(vs1) => {
730                let vs1 = match l1.contiguous_offsets() {
731                    Some((a, b)) => &vs1[a..b],
732                    None => Err(Error::RequiresContiguous { op: "bitwise-unary" }.bt())?,
733                };
734                let result = self.bitwise(vs1);
735                let result = CpuStorage::I32(result);
736                Ok((result, l1.shape().clone()))
737            }
738            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
739        }
740    }
741
742    #[cfg(feature = "cuda")]
743    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
744        candle_core::bail!("bitwise-not is not implemented for the CUDA backend")
745    }
746
747    #[cfg(feature = "metal")]
748    fn metal_fwd(
749        &self,
750        s1: &candle_core::MetalStorage,
751        l1: &Layout,
752    ) -> Result<(candle_core::MetalStorage, Shape)> {
753        if !l1.is_contiguous() {
754            candle_core::bail!("Input tensor s1 must be contiguous");
755        }
756
757        let encoder = s1.device().command_encoder()?;
758        encoder.set_label("bitwise-unary-op");
759
760        let device = s1.device();
761
762        let out_shape = l1.shape().clone();
763
764        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
765
766        match self.op {
767            BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
768                device.device(),
769                &encoder,
770                &crate::metal_kernels::Kernels::new(),
771                s1.dtype(),
772                s1.buffer(),
773                l1.start_offset() * s1.dtype().size_in_bytes(),
774                out_shape.elem_count(),
775                &output,
776            )
777            .map_err(candle_core::Error::wrap)?,
778        }
779
780        let newstorage = candle_core::MetalStorage::new(
781            output,
782            device.clone(),
783            out_shape.elem_count(),
784            s1.dtype(),
785        );
786        Ok((newstorage, out_shape))
787    }
788}
789
790#[allow(dead_code)]
791pub trait BitWiseOp {
792    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
793    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
794    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
795    fn bitwise_not(&self) -> Result<Tensor>;
796}
797
798impl BitWiseOp for Tensor {
799    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
800        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
801    }
802
803    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
804        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
805    }
806
807    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
808        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
809    }
810
811    fn bitwise_not(&self) -> Result<Tensor> {
812        self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
813    }
814}
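// A minimal usage sketch of `BitWiseOp`, assuming a CPU device: both operands must share
// shape, stride, and integer dtype, and the unary NOT flips every bit of each element.
#[cfg(test)]
mod bitwise_example {
    use super::BitWiseOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn bitwise_ops_u8() -> Result<()> {
        let a = Tensor::new(&[0b1100u8, 0b1010], &Device::Cpu)?;
        let b = Tensor::new(&[0b1010u8, 0b0110], &Device::Cpu)?;
        assert_eq!(a.bitwise_and(&b)?.to_vec1::<u8>()?, vec![0b1000u8, 0b0010]);
        assert_eq!(a.bitwise_or(&b)?.to_vec1::<u8>()?, vec![0b1110u8, 0b1110]);
        assert_eq!(a.bitwise_xor(&b)?.to_vec1::<u8>()?, vec![0b0110u8, 0b1100]);
        assert_eq!(a.bitwise_not()?.to_vec1::<u8>()?, vec![!0b1100u8, !0b1010u8]);
        Ok(())
    }
}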
815
816// ────────────────────────────── ArgSort / Sort ────────────────────────────────
817
818#[allow(unused)]
819/// Configuration for an **argsort** (returns indices) operation.
820struct ArgSort {
821    axis: usize,
822}
823
824#[allow(unused)]
825/// Configuration for a **sort** (returns re‑ordered values) operation.
826struct Sort {
827    axis: usize,
828}
829
830impl CustomOp1 for ArgSort {
831    fn name(&self) -> &'static str {
832        "argsort"
833    }
834
835    // -------- CPU ------------------------------------------------------------
836    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
837        candle_core::bail!("ArgSort is not implemented for the CPU backend");
838    }
839
840    // -------- CUDA -----------------------------------------------------------
841    #[cfg(feature = "cuda")]
842    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
843        candle_core::bail!("ArgSort is not implemented for the CUDA backend");
844    }
845
846    // -------- Metal ----------------------------------------------------------
847    #[cfg(feature = "metal")]
848    fn metal_fwd(
849        &self,
850        s1: &candle_core::MetalStorage,
851        l1: &Layout,
852    ) -> Result<(candle_core::MetalStorage, Shape)> {
853        // Require contiguous input (same as other metal ops in this file)
854        if !l1.is_contiguous() {
855            candle_core::bail!("Input tensor s1 must be contiguous");
856        }
857
858        // Create a command encoder and label it for easy debugging in Xcode’s GPU frame‑capture
859        let encoder = s1.device().command_encoder()?;
860        encoder.set_label("argsort");
861
862        let device = s1.device();
863        let out_shape = l1.shape().clone();
864        let elem_count = out_shape.elem_count();
865
866        // Output buffer holds the sorted indices → always `U32`
867        let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
868
869        // ------------------------------------------------------------------
870        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
871        // ------------------------------------------------------------------
872        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
873
874        let dims = l1.dims();
875        let size_sorted_axis = dims[self.axis];
876        let n_rows = l1.shape().elem_count() / size_sorted_axis;
877
878        // Replicate the kernel’s internal block sizing to derive `n_blocks`
879        let tn = 4usize;
880        let mut bn = match size_sorted_axis.div_ceil(tn) {
881            v if v > 256 => 512,
882            v if v > 128 => 256,
883            v if v > 64 => 128,
884            v if v > 32 => 64,
885            _ => 32,
886        };
887        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
888            bn = 256;
889        }
890        let n_per_block = bn * tn;
891        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
892
893        // Borrow the buffers for this launch
894        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
895
896        // ------------------------------------------------------------------
897        // Build the unified SortArgs payload
898        // ------------------------------------------------------------------
899        let sort_args = crate::metal_kernels::SortArgs {
900            axis: self.axis,
901            shape: l1.dims(),
902            strides: l1.stride(),
903            out_shape: l1.dims(), // same as input for argsort
904            out_strides: l1.stride(),
905            in_contiguous: l1.is_contiguous(),
906            in_ty: s1.dtype(),
907            out_ty: candle_core::DType::U32,
908            src: s1.buffer(),
909            src_offset: l1.start_offset(), // element offset
910            dst: &output,
911            bn,
912            tn,
913            n_blocks,
914        };
915
916        // Launch the Metal kernel via the new API
917        crate::metal_kernels::call_argsort(
918            device.device(),
919            &encoder, // impl EncoderProvider
920            &crate::metal_kernels::Kernels::new(),
921            &sort_args,
922            &scratch,
923        )
924        .map_err(candle_core::Error::wrap)?;
925
926        // Wrap and return as a new MetalStorage
927        let newstorage = candle_core::MetalStorage::new(
928            output,
929            device.clone(),
930            elem_count,
931            candle_core::DType::U32,
932        );
933        Ok((newstorage, out_shape))
934    }
935}
936
937impl CustomOp1 for Sort {
938    fn name(&self) -> &'static str {
939        "sort"
940    }
941
942    // -------- CPU ------------------------------------------------------------
943    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
944        candle_core::bail!("Sort is not implemented for the CPU backend");
945    }
946
947    // -------- CUDA -----------------------------------------------------------
948    #[cfg(feature = "cuda")]
949    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
950        candle_core::bail!("Sort is not implemented for the CUDA backend");
951    }
952
953    // -------- Metal ----------------------------------------------------------
954    #[cfg(feature = "metal")]
955    fn metal_fwd(
956        &self,
957        s1: &candle_core::MetalStorage,
958        l1: &Layout,
959    ) -> Result<(candle_core::MetalStorage, Shape)> {
960        // Require contiguous input (same as other metal ops in this file)
961        if !l1.is_contiguous() {
962            candle_core::bail!("Input tensor s1 must be contiguous");
963        }
964
965        // Create a command encoder and label it for easy debugging in Xcode’s GPU frame‑capture
966        let encoder = s1.device().command_encoder()?;
967        encoder.set_label("sort");
968
969        let device = s1.device();
970        let out_shape = l1.shape().clone();
971        let elem_count = out_shape.elem_count();
972
973        // Output buffer keeps the same dtype as the input (these are the reordered values)
974        let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
975
976        // ------------------------------------------------------------------
977        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
978        // ------------------------------------------------------------------
979        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
980
981        let dims = l1.dims();
982        let size_sorted_axis = dims[self.axis];
983        let n_rows = l1.shape().elem_count() / size_sorted_axis;
984
985        // Replicate the kernel’s internal block sizing to derive `n_blocks`
986        let tn = 4usize;
987        let mut bn = match size_sorted_axis.div_ceil(tn) {
988            v if v > 256 => 512,
989            v if v > 128 => 256,
990            v if v > 64 => 128,
991            v if v > 32 => 64,
992            _ => 32,
993        };
994        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
995            bn = 256;
996        }
997        let n_per_block = bn * tn;
998        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
999
1000        // Borrow the buffers for this launch
1001        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1002
1003        // ------------------------------------------------------------------
1004        // Build the unified SortArgs payload
1005        // ------------------------------------------------------------------
1006        let sort_args = crate::metal_kernels::SortArgs {
1007            axis: self.axis,
1008            shape: l1.dims(),
1009            strides: l1.stride(),
1010            out_shape: l1.dims(), // same shape for value sort
1011            out_strides: l1.stride(),
1012            in_contiguous: l1.is_contiguous(),
1013            in_ty: s1.dtype(),
1014            out_ty: s1.dtype(),
1015            src: s1.buffer(),
1016            src_offset: l1.start_offset(), // element offset
1017            dst: &output,
1018            bn,
1019            tn,
1020            n_blocks,
1021        };
1022
1023        // Launch the Metal kernel via the new API
1024        crate::metal_kernels::call_sort(
1025            device.device(),
1026            &encoder, // impl EncoderProvider
1027            &crate::metal_kernels::Kernels::new(),
1028            &sort_args,
1029            &scratch,
1030        )
1031        .map_err(candle_core::Error::wrap)?;
1032
1033        // Wrap and return as a new MetalStorage
1034        let newstorage =
1035            candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1036        Ok((newstorage, out_shape))
1037    }
1038}
1039
1040/// Extension trait adding `argsort` / `sort` convenience calls on `Tensor`.
1041pub trait SortOp {
1042    /// Returns the indices that would (ascending) sort the tensor along `axis`.
1043    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1044    /// Returns the tensor's values (ascending) sorted along `axis`.
1045    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1046}
1047
1048impl SortOp for Tensor {
1049    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1050        if self.device().is_cpu() || self.device().is_cuda() {
1051            return self.arg_sort_last_dim(true);
1052        }
1053        self.apply_op1_no_bwd(&ArgSort {
1054            axis: axis.to_index(self.shape(), "argsort")?,
1055        })
1056    }
1057
1058    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1059        if self.device().is_cpu() || self.device().is_cuda() {
1060            return Ok(self.sort_last_dim(true)?.0);
1061        }
1062        self.apply_op1_no_bwd(&Sort {
1063            axis: axis.to_index(self.shape(), "sort")?,
1064        })
1065    }
1066}
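// A minimal usage sketch of `SortOp`, assuming a CPU device: on CPU and CUDA the fast_*
// helpers fall back to candle's built-in last-dim sort, so the axis used here is the last
// dimension; on Metal the custom kernels above are used instead.
#[cfg(test)]
mod sort_example {
    use super::SortOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn sort_and_argsort_last_dim() -> Result<()> {
        let t = Tensor::new(&[3u32, 1, 2], &Device::Cpu)?;
        // Ascending values and the indices that produce them.
        assert_eq!(t.fast_sort_asc(0)?.to_vec1::<u32>()?, vec![1u32, 2, 3]);
        assert_eq!(t.fast_argsort_asc(0)?.to_vec1::<u32>()?, vec![1u32, 2, 0]);
        Ok(())
    }
}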
1067
1068struct NonZero;
1069
1070impl NonZero {
1071    // Sequential version
1072    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1073        let n = layout.dims().len();
1074        let mut result = Vec::new();
1075        let mut indices = vec![0u32; n];
1076        for (i, v) in vs.iter().enumerate() {
1077            if !v.is_zero() {
1078                let mut idx = i;
1079                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1080                    let d = idx % dim;
1081                    indices[dim_index] = u32::try_from(d).unwrap();
1082                    idx /= dim;
1083                }
1084                result.extend_from_slice(&indices);
1085            }
1086        }
1087        result
1088    }
1089}
1090
1091#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1092mod cuda_ops_cccl2 {
1093    use super::*;
1094
1095    pub(super) fn count_nonzero_cuda(
1096        dtype: candle_core::DType,
1097        d_in: *const c_void,
1098        n: u32,
1099        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1100    ) -> u32 {
1101        unsafe {
1102            match dtype {
1103                candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1104                candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1105                candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1106                candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1107                candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1108                candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1109                candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1110                candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1111                candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1112                _ => unreachable!(),
1113            }
1114        }
1115    }
1116
1117    #[allow(clippy::too_many_arguments)]
1118    pub(super) fn nonzero_cuda(
1119        dtype: candle_core::DType,
1120        d_in: *const c_void,
1121        n: u32,
1122        num_nonzero: u32,
1123        dims: *const c_void,
1124        num_dims: u32,
1125        d_out: *mut c_void,
1126        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1127    ) {
1128        unsafe {
1129            match dtype {
1130                candle_core::DType::U8 => {
1131                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1132                }
1133                candle_core::DType::U32 => {
1134                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1135                }
1136                candle_core::DType::I64 => {
1137                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1138                }
1139                candle_core::DType::I32 => {
1140                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1141                }
1142                candle_core::DType::I16 => {
1143                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1144                }
1145                candle_core::DType::BF16 => {
1146                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1147                }
1148                candle_core::DType::F16 => {
1149                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1150                }
1151                candle_core::DType::F32 => {
1152                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1153                }
1154                candle_core::DType::F64 => {
1155                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1156                }
1157                _ => unreachable!(),
1158            }
1159        }
1160    }
1161}
1162
1163#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1164mod cuda_ops_cccl3 {
1165    use super::*;
1166
1167    pub(super) fn count_nonzero_cuda(
1168        dtype: candle_core::DType,
1169        d_in: *const c_void,
1170        n: u32,
1171        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1172    ) -> u32 {
1173        unsafe {
1174            match dtype {
1175                candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1176                candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1177                candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1178                candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1179                candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1180                candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1181                candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1182                candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1183                candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1184                _ => unreachable!(),
1185            }
1186        }
1187    }
1188
1189    #[allow(clippy::too_many_arguments)]
1190    pub(super) fn nonzero_cuda(
1191        dtype: candle_core::DType,
1192        d_in: *const c_void,
1193        n: u32,
1194        num_nonzero: u32,
1195        dims: *const c_void,
1196        num_dims: u32,
1197        d_out: *mut c_void,
1198        stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1199    ) {
1200        unsafe {
1201            match dtype {
1202                candle_core::DType::U8 => {
1203                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1204                }
1205                candle_core::DType::U32 => {
1206                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1207                }
1208                candle_core::DType::I64 => {
1209                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1210                }
1211                candle_core::DType::I32 => {
1212                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1213                }
1214                candle_core::DType::I16 => {
1215                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1216                }
1217                candle_core::DType::BF16 => {
1218                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1219                }
1220                candle_core::DType::F16 => {
1221                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1222                }
1223                candle_core::DType::F32 => {
1224                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1225                }
1226                candle_core::DType::F64 => {
1227                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1228                }
1229                _ => unreachable!(),
1230            }
1231        }
1232    }
1233}
1234
1235#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1236use cuda_ops_cccl2::{count_nonzero_cuda, nonzero_cuda};
1237#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1238use cuda_ops_cccl3::{count_nonzero_cuda, nonzero_cuda};
1239
1240impl CustomOp1 for NonZero {
1241    fn name(&self) -> &'static str {
1242        "nonzero"
1243    }
1244
1245    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1246        if !layout.is_contiguous() {
1247            return Err(Error::RequiresContiguous { op: "nonzero" });
1248        }
1249        let result = match storage {
1250            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1251            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1252            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1253            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1254            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1255            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1256            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1257            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1258            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1259            _ => unreachable!(),
1260        };
1261        let index_len = layout.dims().len();
1262        let result_len = result.len() / index_len;
1263        let result = CpuStorage::U32(result);
1264        let shape = Shape::from_dims(&[result_len, index_len]);
1265        Ok((result, shape))
1266    }
1267
1268    #[cfg(feature = "cuda")]
1269    fn cuda_fwd(
1270        &self,
1271        storage: &candle_core::CudaStorage,
1272        layout: &Layout,
1273    ) -> Result<(candle_core::CudaStorage, Shape)> {
1274        if !layout.is_contiguous() {
1275            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1276        }
1277        let dev = storage.device().clone();
1278        let (d_in, _d_in_guard) = match storage.dtype() {
1279            candle_core::DType::U8 => {
1280                let slice = storage.as_cuda_slice::<u8>()?;
1281                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1282                (d_in as *const std::ffi::c_void, d_in_guard)
1283            }
1284            candle_core::DType::U32 => {
1285                let slice = storage.as_cuda_slice::<u32>()?;
1286                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1287                (d_in as *const std::ffi::c_void, d_in_guard)
1288            }
1289            candle_core::DType::I32 => {
1290                let slice = storage.as_cuda_slice::<i32>()?;
1291                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1292                (d_in as *const std::ffi::c_void, d_in_guard)
1293            }
1294            candle_core::DType::I16 => {
1295                let slice = storage.as_cuda_slice::<i16>()?;
1296                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1297                (d_in as *const std::ffi::c_void, d_in_guard)
1298            }
1299            candle_core::DType::I64 => {
1300                let slice = storage.as_cuda_slice::<i64>()?;
1301                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1302                (d_in as *const std::ffi::c_void, d_in_guard)
1303            }
1304            candle_core::DType::BF16 => {
1305                let slice = storage.as_cuda_slice::<half::bf16>()?;
1306                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1307                (d_in as *const std::ffi::c_void, d_in_guard)
1308            }
1309            candle_core::DType::F16 => {
1310                let slice = storage.as_cuda_slice::<half::f16>()?;
1311                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1312                (d_in as *const std::ffi::c_void, d_in_guard)
1313            }
1314            candle_core::DType::F32 => {
1315                let slice = storage.as_cuda_slice::<f32>()?;
1316                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1317                (d_in as *const std::ffi::c_void, d_in_guard)
1318            }
1319            candle_core::DType::F64 => {
1320                let slice = storage.as_cuda_slice::<f64>()?;
1321                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1322                (d_in as *const std::ffi::c_void, d_in_guard)
1323            }
1324            _ => unreachable!(),
1325        };
1326        let n = layout.shape().elem_count();
1327
1328        let num_nonzero = count_nonzero_cuda(
1329            storage.dtype(),
1330            d_in,
1331            u32::try_from(n)?,
1332            dev.cuda_stream().cu_stream(),
1333        );
1334        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1335            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1336        if num_nonzero != 0 {
1337            let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1338            let dims = layout
1339                .dims()
1340                .iter()
1341                .map(|&x| u32::try_from(x).unwrap())
1342                .collect::<Vec<u32>>();
1343            let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1344            dev.memcpy_htod(&dims, &mut d_dims)?;
1345            let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1346            nonzero_cuda(
1347                storage.dtype(),
1348                d_in,
1349                u32::try_from(n)?,
1350                num_nonzero,
1351                d_dims_ptr as *const c_void,
1352                u32::try_from(layout.dims().len())?,
1353                d_out as *mut c_void,
1354                dev.cuda_stream().cu_stream(),
1355            );
1356        }
1357        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1358        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1359        Ok((dst, shape))
1360    }
1361}
1362
1363pub trait NonZeroOp {
1364    fn nonzero(&self) -> Result<Tensor>;
1365}
1366
1367impl NonZeroOp for Tensor {
1368    #[cfg(feature = "metal")]
1369    fn nonzero(&self) -> Result<Tensor> {
1370        if !self.is_contiguous() {
1371            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1372        }
1373        let original_device = self.device();
1374        self.to_device(&candle_core::Device::Cpu)?
1375            .apply_op1_no_bwd(&NonZero)?
1376            .to_device(original_device)
1377    }
1378
1379    #[cfg(not(feature = "metal"))]
1380    fn nonzero(&self) -> Result<Tensor> {
1381        if !self.is_contiguous() {
1382            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1383        }
1384        self.apply_op1_no_bwd(&NonZero)
1385    }
1386}
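// A minimal usage sketch of `NonZeroOp`, assuming a CPU device: `nonzero` returns a
// `[num_nonzero, rank]` tensor of u32 coordinates in row-major order.
#[cfg(test)]
mod nonzero_example {
    use super::NonZeroOp;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn nonzero_2d() -> Result<()> {
        let t = Tensor::new(&[[0f32, 1.0], [2.0, 0.0]], &Device::Cpu)?;
        let idx = t.nonzero()?;
        assert_eq!(idx.dims(), &[2, 2]);
        // Non-zero values sit at (0, 1) and (1, 0).
        assert_eq!(idx.to_vec2::<u32>()?, vec![vec![0u32, 1], vec![1, 0]]);
        Ok(())
    }
}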
1387
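// Cumulative sum along `axis`: `inclusive` controls whether each output element includes
// its own input value, and `reverse` scans from the end of the axis instead of the start.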
1388struct CumSum {
1389    inclusive: bool,
1390    reverse: bool,
1391    axis: usize,
1392}
1393
1394impl CustomOp1 for CumSum {
1395    fn name(&self) -> &'static str {
1396        "cumsum"
1397    }
1398
1399    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1400        use std::ops::Add;
1401        if !l1.is_contiguous() {
1402            candle_core::bail!("cumsum: input tensor must be contiguous");
1403        }
1404        let dims = l1.dims();
1405        let axis = self.axis;
1406        let axis_len = dims[axis];
1407        let (start, end) = l1
1408            .contiguous_offsets()
1409            .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1410
1411        // helper to execute scan for a slice of T
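        // For one block [1, 2, 3, 4] along the axis this yields:
        //   forward inclusive: [1, 3, 6, 10]    forward exclusive: [0, 1, 3, 6]
        //   reverse inclusive: [10, 9, 7, 4]    reverse exclusive: [9, 7, 4, 0]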
1412        macro_rules! scan_block {
1413            ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1414                let vs: &[$ty] = $vt;
1415                let input = &vs[start..end];
1416                let count = input.len() / axis_len;
1417                let mut result = Vec::<$ty>::with_capacity(input.len());
1418                if !self.reverse {
1419                    if self.inclusive {
1420                        for block in 0..count {
1421                            let base = block * axis_len;
1422                            let mut sum = input[base];
1423                            result.push(sum);
1424                            for j in 1..axis_len {
1425                                sum = sum.$add(input[base + j]);
1426                                result.push(sum);
1427                            }
1428                        }
1429                    } else {
1430                        let init: $ty = $init;
1431                        for block in 0..count {
1432                            let base = block * axis_len;
1433                            let mut sum = init;
1434                            for j in 0..axis_len {
1435                                result.push(sum);
1436                                sum = sum.$add(input[base + j]);
1437                            }
1438                        }
1439                    }
1440                } else {
1441                    if self.inclusive {
1442                        for block in 0..count {
1443                            let base = block * axis_len;
1444                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1445                            let mut sum = input[base + axis_len - 1];
1446                            temp.push(sum);
1447                            for k in 1..axis_len {
1448                                let idx = axis_len - 1 - k;
1449                                sum = sum.$add(input[base + idx]);
1450                                temp.push(sum);
1451                            }
1452                            temp.reverse();
1453                            result.extend(temp);
1454                        }
1455                    } else {
1456                        let init: $ty = $init;
1457                        for block in 0..count {
1458                            let base = block * axis_len;
1459                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1460                            let mut sum = init;
1461                            for k in 0..axis_len {
1462                                let idx = axis_len - 1 - k;
1463                                temp.push(sum);
1464                                sum = sum.$add(input[base + idx]);
1465                            }
1466                            temp.reverse();
1467                            result.extend(temp);
1468                        }
1469                    }
1470                }
1471                result
1472            }};
1473        }
1474        match s1 {
1475            CpuStorage::U8(vs) => {
1476                let result = scan_block!(vs, u8, wrapping_add, 0u8);
1477                Ok((CpuStorage::U8(result), l1.shape().clone()))
1478            }
1479            CpuStorage::I16(vs) => {
1480                let result = scan_block!(vs, i16, add, 0i16);
1481                Ok((CpuStorage::I16(result), l1.shape().clone()))
1482            }
1483            CpuStorage::U32(vs) => {
1484                let result = scan_block!(vs, u32, wrapping_add, 0u32);
1485                Ok((CpuStorage::U32(result), l1.shape().clone()))
1486            }
1487            CpuStorage::I32(vs) => {
1488                let result = scan_block!(vs, i32, add, 0i32);
1489                Ok((CpuStorage::I32(result), l1.shape().clone()))
1490            }
1491            CpuStorage::I64(vs) => {
1492                let result = scan_block!(vs, i64, add, 0i64);
1493                Ok((CpuStorage::I64(result), l1.shape().clone()))
1494            }
1495            CpuStorage::F32(vs) => {
1496                let result = scan_block!(vs, f32, add, 0.0f32);
1497                Ok((CpuStorage::F32(result), l1.shape().clone()))
1498            }
1499            CpuStorage::F64(vs) => {
1500                let result = scan_block!(vs, f64, add, 0.0f64);
1501                Ok((CpuStorage::F64(result), l1.shape().clone()))
1502            }
1503            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1504        }
1505    }
1506
1507    #[cfg(feature = "cuda")]
1508    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
1509        candle_core::bail!("cumsum: CUDA forward is not implemented; move the tensor to the CPU and use fast_cumsum there")
1510    }
1511
1512    #[cfg(feature = "metal")]
1513    fn metal_fwd(
1514        &self,
1515        s1: &candle_core::MetalStorage,
1516        l1: &Layout,
1517    ) -> Result<(candle_core::MetalStorage, Shape)> {
1518        use crate::metal_kernels::ScanType;
1519
1520        let encoder = s1.device().command_encoder()?;
1521        encoder.set_label("cumsum");
1522
1523        let device = s1.device();
1524
1525        let out_shape = l1.shape().clone();
1526
1527        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1528
1529        crate::metal_kernels::call_scan(
1530            device.device(),
1531            &encoder,
1532            &crate::metal_kernels::Kernels::new(),
1533            s1.dtype(),
1534            ScanType::Sum,
1535            s1.buffer(),
1536            l1.start_offset() * s1.dtype().size_in_bytes(),
1537            self.axis,
1538            l1.dims(),
1539            l1.stride(),
1540            self.reverse,
1541            self.inclusive,
1542            &output,
1543        )
1544        .map_err(candle_core::Error::wrap)?;
1545
1546        let newstorage = candle_core::MetalStorage::new(
1547            output,
1548            device.clone(),
1549            out_shape.elem_count(),
1550            s1.dtype(),
1551        );
1552        Ok((newstorage, out_shape))
1553    }
1554}
1555
1556#[allow(dead_code)]
1557pub trait CumSumOp {
1558    /// Exclusive forward scan (inclusive = false, reverse = false).
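    ///
    /// Illustrative usage on a CPU tensor (same values as the unit tests below):
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu)?;
    /// assert_eq!(a.fast_cumsum(0)?.to_vec1::<i64>()?, [0, 1, 3, 6]); // exclusive
    /// assert_eq!(a.fast_cumsum_config(0, true, false)?.to_vec1::<i64>()?, [1, 3, 6, 10]); // inclusive
    /// ```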
1559    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1560
1561    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1562        -> Result<Tensor>;
1563}
1564
1565impl CumSumOp for Tensor {
1566    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1567        self.fast_cumsum_config(axis, false, false)
1568    }
1569
1570    fn fast_cumsum_config<D: Dim>(
1571        &self,
1572        axis: D,
1573        inclusive: bool,
1574        reverse: bool,
1575    ) -> Result<Tensor> {
1576        self.apply_op1_no_bwd(&CumSum {
1577            inclusive,
1578            reverse,
1579            axis: axis.to_index(self.shape(), "cumsum")?,
1580        })
1581    }
1582}
1583
1584/// Fused GPT-OSS SwiGLU activation
1585/// Formula: output = (clamp(up, -limit, limit) + 1) * gate_clamped * sigmoid(gate_clamped * alpha)
1586/// where gate_clamped = min(gate, limit)
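///
/// A minimal scalar sketch of the formula above (illustrative only; the actual work is
/// done by the per-dtype CUDA kernels below):
///
/// ```ignore
/// fn swiglu_ref(gate: f32, up: f32, alpha: f32, limit: f32) -> f32 {
///     let g = gate.min(limit);          // gate_clamped
///     let u = up.clamp(-limit, limit);  // clamp(up, -limit, limit)
///     (u + 1.0) * g * (1.0 / (1.0 + (-alpha * g).exp())) // (u + 1) * g * sigmoid(g * alpha)
/// }
/// ```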
1587#[cfg(feature = "cuda")]
1588pub fn gptoss_swiglu_fused(gate: &Tensor, up: &Tensor, alpha: f32, limit: f32) -> Result<Tensor> {
1589    use half::{bf16, f16};
1590
1591    let gate = gate.contiguous()?;
1592    let up = up.contiguous()?;
1593
1594    if gate.shape() != up.shape() {
1595        candle_core::bail!(
1596            "gptoss_swiglu: gate and up must have same shape, got {:?} vs {:?}",
1597            gate.shape(),
1598            up.shape()
1599        );
1600    }
1601
1602    let device = match gate.device() {
1603        candle_core::Device::Cuda(dev) => dev,
1604        _ => candle_core::bail!("gptoss_swiglu requires CUDA device"),
1605    };
1606
1607    let n_elements = gate.elem_count();
1608    let dtype = gate.dtype();
1609
1610    let gate_storage = gate.storage_and_layout().0;
1611    let up_storage = up.storage_and_layout().0;
1612
1613    let gate_cuda = match &*gate_storage {
1614        candle_core::Storage::Cuda(s) => s,
1615        _ => candle_core::bail!("Expected CUDA storage for gate"),
1616    };
1617    let up_cuda = match &*up_storage {
1618        candle_core::Storage::Cuda(s) => s,
1619        _ => candle_core::bail!("Expected CUDA storage for up"),
1620    };
1621
1622    let stream = device.cuda_stream().cu_stream();
1623
1624    match dtype {
1625        DType::F16 => {
1626            let output = device.alloc_zeros::<f16>(n_elements)?;
1627            let gate_slice = gate_cuda.as_cuda_slice::<f16>()?;
1628            let up_slice = up_cuda.as_cuda_slice::<f16>()?;
1629
1630            let (gate_ptr, _g_guard) = slice_ptr(gate_slice, 0);
1631            let (up_ptr, _u_guard) = slice_ptr(up_slice, 0);
1632            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1633
1634            unsafe {
1635                ffi::gptoss_swiglu_f16(
1636                    gate_ptr as *const c_void,
1637                    up_ptr as *const c_void,
1638                    out_ptr as *mut c_void,
1639                    n_elements as u32,
1640                    alpha,
1641                    limit,
1642                    stream,
1643                );
1644            }
1645
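            // Release the borrow on `output` taken by `slice_ptr` before moving it into
            // the CudaStorage below.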
1646            drop(_o_guard);
1647            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1648            Ok(Tensor::from((
1649                candle_core::Storage::Cuda(out_storage),
1650                gate.shape().clone(),
1651            )))
1652        }
1653        DType::BF16 => {
1654            let output = device.alloc_zeros::<bf16>(n_elements)?;
1655            let gate_slice = gate_cuda.as_cuda_slice::<bf16>()?;
1656            let up_slice = up_cuda.as_cuda_slice::<bf16>()?;
1657
1658            let (gate_ptr, _g_guard) = slice_ptr(gate_slice, 0);
1659            let (up_ptr, _u_guard) = slice_ptr(up_slice, 0);
1660            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1661
1662            unsafe {
1663                ffi::gptoss_swiglu_bf16(
1664                    gate_ptr as *const c_void,
1665                    up_ptr as *const c_void,
1666                    out_ptr as *mut c_void,
1667                    n_elements as u32,
1668                    alpha,
1669                    limit,
1670                    stream,
1671                );
1672            }
1673
1674            drop(_o_guard);
1675            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1676            Ok(Tensor::from((
1677                candle_core::Storage::Cuda(out_storage),
1678                gate.shape().clone(),
1679            )))
1680        }
1681        DType::F32 => {
1682            let output = device.alloc_zeros::<f32>(n_elements)?;
1683            let gate_slice = gate_cuda.as_cuda_slice::<f32>()?;
1684            let up_slice = up_cuda.as_cuda_slice::<f32>()?;
1685
1686            let (gate_ptr, _g_guard) = slice_ptr(gate_slice, 0);
1687            let (up_ptr, _u_guard) = slice_ptr(up_slice, 0);
1688            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1689
1690            unsafe {
1691                ffi::gptoss_swiglu_f32(
1692                    gate_ptr as *const c_void,
1693                    up_ptr as *const c_void,
1694                    out_ptr as *mut c_void,
1695                    n_elements as u32,
1696                    alpha,
1697                    limit,
1698                    stream,
1699                );
1700            }
1701
1702            drop(_o_guard);
1703            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1704            Ok(Tensor::from((
1705                candle_core::Storage::Cuda(out_storage),
1706                gate.shape().clone(),
1707            )))
1708        }
1709        _ => candle_core::bail!("gptoss_swiglu: unsupported dtype {:?}", dtype),
1710    }
1711}
1712
1713/// Fused GPT-OSS SwiGLU for interleaved gate/up data.
1714///
1715/// This consumes the interleaved gate/up layout directly, avoiding the two tensor
1716/// copies that extracting gate and up via `narrow().squeeze().contiguous()` would incur.
1717///
1718/// Args:
1719///   gate_up: [N, intermediate_size, 2] - interleaved gate/up data
1720///   alpha: SwiGLU alpha parameter
1721///   limit: SwiGLU limit parameter
1722///
1723/// Returns: [N, intermediate_size] - activated output
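///
/// For reference, the un-fused path this replaces looks roughly like the sketch below
/// (hypothetical call site; it materializes two extra contiguous copies):
///
/// ```ignore
/// let gate = gate_up.narrow(2, 0, 1)?.squeeze(2)?.contiguous()?; // [N, intermediate_size]
/// let up   = gate_up.narrow(2, 1, 1)?.squeeze(2)?.contiguous()?; // [N, intermediate_size]
/// let out  = gptoss_swiglu_fused(&gate, &up, alpha, limit)?;
/// ```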
1724#[cfg(feature = "cuda")]
1725pub fn gptoss_swiglu_interleaved(
1726    gate_up: &Tensor,
1727    intermediate_size: usize,
1728    alpha: f32,
1729    limit: f32,
1730) -> Result<Tensor> {
1731    use half::{bf16, f16};
1732    use std::ffi::c_void;
1733
1734    let gate_up = gate_up.contiguous()?;
1735
1736    let dims = gate_up.dims();
1737    if dims.len() != 3 || dims[2] != 2 {
1738        candle_core::bail!(
1739            "gptoss_swiglu_interleaved: expected gate_up shape [N, intermediate_size, 2], got {:?}",
1740            dims
1741        );
1742    }
1743
1744    let n = dims[0]; // num_tokens * topk
1745    let device = match gate_up.device() {
1746        candle_core::Device::Cuda(dev) => dev,
1747        _ => candle_core::bail!("gptoss_swiglu_interleaved requires CUDA device"),
1748    };
1749
1750    let dtype = gate_up.dtype();
1751    let n_output_elements = n * intermediate_size;
1752
1753    let gate_up_storage = gate_up.storage_and_layout().0;
1754    let gate_up_cuda = match &*gate_up_storage {
1755        candle_core::Storage::Cuda(s) => s,
1756        _ => candle_core::bail!("Expected CUDA storage for gate_up"),
1757    };
1758
1759    let stream = device.cuda_stream().cu_stream();
1760
1761    match dtype {
1762        DType::F16 => {
1763            let output = device.alloc_zeros::<f16>(n_output_elements)?;
1764            let gate_up_slice = gate_up_cuda.as_cuda_slice::<f16>()?;
1765
1766            let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, 0);
1767            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1768
1769            unsafe {
1770                ffi::gptoss_swiglu_interleaved_f16(
1771                    gate_up_ptr as *const c_void,
1772                    out_ptr as *mut c_void,
1773                    n as u32,
1774                    intermediate_size as u32,
1775                    alpha,
1776                    limit,
1777                    stream,
1778                );
1779            }
1780
1781            drop(_o_guard);
1782            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1783            Ok(Tensor::from((
1784                candle_core::Storage::Cuda(out_storage),
1785                Shape::from(vec![n, intermediate_size]),
1786            )))
1787        }
1788        DType::BF16 => {
1789            let output = device.alloc_zeros::<bf16>(n_output_elements)?;
1790            let gate_up_slice = gate_up_cuda.as_cuda_slice::<bf16>()?;
1791
1792            let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, 0);
1793            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1794
1795            unsafe {
1796                ffi::gptoss_swiglu_interleaved_bf16(
1797                    gate_up_ptr as *const c_void,
1798                    out_ptr as *mut c_void,
1799                    n as u32,
1800                    intermediate_size as u32,
1801                    alpha,
1802                    limit,
1803                    stream,
1804                );
1805            }
1806
1807            drop(_o_guard);
1808            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1809            Ok(Tensor::from((
1810                candle_core::Storage::Cuda(out_storage),
1811                Shape::from(vec![n, intermediate_size]),
1812            )))
1813        }
1814        DType::F32 => {
1815            let output = device.alloc_zeros::<f32>(n_output_elements)?;
1816            let gate_up_slice = gate_up_cuda.as_cuda_slice::<f32>()?;
1817
1818            let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, 0);
1819            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1820
1821            unsafe {
1822                ffi::gptoss_swiglu_interleaved_f32(
1823                    gate_up_ptr as *const c_void,
1824                    out_ptr as *mut c_void,
1825                    n as u32,
1826                    intermediate_size as u32,
1827                    alpha,
1828                    limit,
1829                    stream,
1830                );
1831            }
1832
1833            drop(_o_guard);
1834            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1835            Ok(Tensor::from((
1836                candle_core::Storage::Cuda(out_storage),
1837                Shape::from(vec![n, intermediate_size]),
1838            )))
1839        }
1840        _ => candle_core::bail!("gptoss_swiglu_interleaved: unsupported dtype {:?}", dtype),
1841    }
1842}
1843
1844/// Fused softmax with sinks for GPT-OSS attention.
1845///
1846/// This computes softmax over attention logits while including a per-head "sink" value
1847/// in the normalization, then drops the sink from the output.
1848///
1849/// Args:
1850///   logits: [batch, heads, q_len, k_len] - attention scores (q @ k.T * scale)
1851///   sinks: [heads] - per-head sink values
1852///   mask: Optional [batch, 1, q_len, k_len] - attention mask (0 = attend, -inf = mask)
1853///
1854/// Returns: [batch, heads, q_len, k_len] - softmax probabilities (the sink contributes to the denominator but is not part of the output)
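///
/// Ignoring the mask, each (batch, head, query) row is normalized as in this illustrative
/// pseudocode, where `m` is only a numerical-stability shift and cancels out:
///
/// ```ignore
/// let m = row_max.max(sink_h);
/// let denom: f32 = logits_row.iter().map(|l| (l - m).exp()).sum::<f32>() + (sink_h - m).exp();
/// // out_row[k] = (logits_row[k] - m).exp() / denom; the sink itself is never emitted
/// ```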
1855#[cfg(feature = "cuda")]
1856pub fn softmax_with_sinks(
1857    logits: &Tensor,
1858    sinks: &Tensor,
1859    mask: Option<&Tensor>,
1860) -> Result<Tensor> {
1861    use half::{bf16, f16};
1862    use std::ffi::c_void;
1863
1864    let logits = logits.contiguous()?;
1865    let sinks = sinks.contiguous()?;
1866
1867    let dims = logits.dims();
1868    if dims.len() != 4 {
1869        candle_core::bail!(
1870            "softmax_with_sinks: expected logits to have 4 dims [b, h, q, k], got {:?}",
1871            dims
1872        );
1873    }
1874
1875    let batch_size = dims[0];
1876    let num_heads = dims[1];
1877    let q_len = dims[2];
1878    let k_len = dims[3];
1879
1880    if sinks.dims() != [num_heads] {
1881        candle_core::bail!(
1882            "softmax_with_sinks: expected sinks shape [{}], got {:?}",
1883            num_heads,
1884            sinks.dims()
1885        );
1886    }
1887
1888    let device = match logits.device() {
1889        candle_core::Device::Cuda(dev) => dev,
1890        _ => candle_core::bail!("softmax_with_sinks requires CUDA device"),
1891    };
1892
1893    let dtype = logits.dtype();
1894    let n_elements = logits.elem_count();
1895
1896    let logits_storage = logits.storage_and_layout().0;
1897    let sinks_storage = sinks.storage_and_layout().0;
1898
1899    let logits_cuda = match &*logits_storage {
1900        candle_core::Storage::Cuda(s) => s,
1901        _ => candle_core::bail!("Expected CUDA storage for logits"),
1902    };
1903    let sinks_cuda = match &*sinks_storage {
1904        candle_core::Storage::Cuda(s) => s,
1905        _ => candle_core::bail!("Expected CUDA storage for sinks"),
1906    };
1907
1908    // Handle optional mask
1909    let mask = if let Some(m) = mask {
1910        Some(m.contiguous()?)
1911    } else {
1912        None
1913    };
1914
1915    let stream = device.cuda_stream().cu_stream();
1916
1917    match dtype {
1918        DType::F16 => {
1919            let output = device.alloc_zeros::<f16>(n_elements)?;
1920
1921            let logits_slice = logits_cuda.as_cuda_slice::<f16>()?;
1922            let sinks_slice = sinks_cuda.as_cuda_slice::<f16>()?;
1923
1924            let (logits_ptr, _l_guard) = slice_ptr(logits_slice, 0);
1925            let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, 0);
1926            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1927
1928            let mask_ptr = if let Some(ref m) = mask {
1929                let m_storage = m.storage_and_layout().0;
1930                let m_cuda = match &*m_storage {
1931                    candle_core::Storage::Cuda(s) => s,
1932                    _ => candle_core::bail!("Expected CUDA storage for mask"),
1933                };
1934                let m_slice = m_cuda.as_cuda_slice::<f16>()?;
1935                let (ptr, _guard) = slice_ptr(m_slice, 0);
1936                ptr as *const c_void
1937            } else {
1938                std::ptr::null()
1939            };
1940
1941            unsafe {
1942                ffi::softmax_with_sinks_f16(
1943                    logits_ptr as *const c_void,
1944                    sinks_ptr as *const c_void,
1945                    mask_ptr,
1946                    out_ptr as *mut c_void,
1947                    batch_size as i32,
1948                    num_heads as i32,
1949                    q_len as i32,
1950                    k_len as i32,
1951                    1.0, // scale already applied to logits
1952                    stream,
1953                );
1954            }
1955
1956            drop(_o_guard);
1957            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1958            Ok(Tensor::from((
1959                candle_core::Storage::Cuda(out_storage),
1960                logits.shape().clone(),
1961            )))
1962        }
1963        DType::BF16 => {
1964            let output = device.alloc_zeros::<bf16>(n_elements)?;
1965
1966            let logits_slice = logits_cuda.as_cuda_slice::<bf16>()?;
1967            let sinks_slice = sinks_cuda.as_cuda_slice::<bf16>()?;
1968
1969            let (logits_ptr, _l_guard) = slice_ptr(logits_slice, 0);
1970            let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, 0);
1971            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1972
1973            let mask_ptr = if let Some(ref m) = mask {
1974                let m_storage = m.storage_and_layout().0;
1975                let m_cuda = match &*m_storage {
1976                    candle_core::Storage::Cuda(s) => s,
1977                    _ => candle_core::bail!("Expected CUDA storage for mask"),
1978                };
1979                let m_slice = m_cuda.as_cuda_slice::<bf16>()?;
1980                let (ptr, _guard) = slice_ptr(m_slice, 0);
1981                ptr as *const c_void
1982            } else {
1983                std::ptr::null()
1984            };
1985
1986            unsafe {
1987                ffi::softmax_with_sinks_bf16(
1988                    logits_ptr as *const c_void,
1989                    sinks_ptr as *const c_void,
1990                    mask_ptr,
1991                    out_ptr as *mut c_void,
1992                    batch_size as i32,
1993                    num_heads as i32,
1994                    q_len as i32,
1995                    k_len as i32,
1996                    1.0,
1997                    stream,
1998                );
1999            }
2000
2001            drop(_o_guard);
2002            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2003            Ok(Tensor::from((
2004                candle_core::Storage::Cuda(out_storage),
2005                logits.shape().clone(),
2006            )))
2007        }
2008        DType::F32 => {
2009            let output = device.alloc_zeros::<f32>(n_elements)?;
2010
2011            let logits_slice = logits_cuda.as_cuda_slice::<f32>()?;
2012            let sinks_slice = sinks_cuda.as_cuda_slice::<f32>()?;
2013
2014            let (logits_ptr, _l_guard) = slice_ptr(logits_slice, 0);
2015            let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, 0);
2016            let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2017
2018            let mask_ptr = if let Some(ref m) = mask {
2019                let m_storage = m.storage_and_layout().0;
2020                let m_cuda = match &*m_storage {
2021                    candle_core::Storage::Cuda(s) => s,
2022                    _ => candle_core::bail!("Expected CUDA storage for mask"),
2023                };
2024                let m_slice = m_cuda.as_cuda_slice::<f32>()?;
2025                let (ptr, _guard) = slice_ptr(m_slice, 0);
2026                ptr as *const c_void
2027            } else {
2028                std::ptr::null()
2029            };
2030
2031            unsafe {
2032                ffi::softmax_with_sinks_f32(
2033                    logits_ptr as *const c_void,
2034                    sinks_ptr as *const c_void,
2035                    mask_ptr,
2036                    out_ptr as *mut c_void,
2037                    batch_size as i32,
2038                    num_heads as i32,
2039                    q_len as i32,
2040                    k_len as i32,
2041                    1.0,
2042                    stream,
2043                );
2044            }
2045
2046            drop(_o_guard);
2047            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2048            Ok(Tensor::from((
2049                candle_core::Storage::Cuda(out_storage),
2050                logits.shape().clone(),
2051            )))
2052        }
2053        _ => candle_core::bail!("softmax_with_sinks: unsupported dtype {:?}", dtype),
2054    }
2055}
2056
#[cfg(test)]
2057mod tests {
2058    #[test]
2059    fn test_cumsum_exclusive_forward_cpu() {
2060        use crate::utils::ops::CumSumOp;
2061        use candle_core::Tensor;
2062        let device = candle_core::Device::Cpu;
2063        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2064        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
2065        assert_eq!(b, [0, 1, 3, 6]);
2066    }
2067
2068    #[test]
2069    fn test_cumsum_inclusive_forward_cpu() {
2070        use crate::utils::ops::CumSumOp;
2071        use candle_core::Tensor;
2072        let device = candle_core::Device::Cpu;
2073        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2074        let b = a
2075            .fast_cumsum_config(0, true, false)
2076            .unwrap()
2077            .to_vec1::<i64>()
2078            .unwrap();
2079        assert_eq!(b, [1, 3, 6, 10]);
2080    }
2081
2082    #[test]
2083    fn test_cumsum_exclusive_reverse_cpu() {
2084        use crate::utils::ops::CumSumOp;
2085        use candle_core::Tensor;
2086        let device = candle_core::Device::Cpu;
2087        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2088        let b = a
2089            .fast_cumsum_config(0, false, true)
2090            .unwrap()
2091            .to_vec1::<i64>()
2092            .unwrap();
2093        assert_eq!(b, [9, 7, 4, 0]);
2094    }
2095
2096    #[test]
2097    fn test_cumsum_inclusive_reverse_cpu() {
2098        use crate::utils::ops::CumSumOp;
2099        use candle_core::Tensor;
2100        let device = candle_core::Device::Cpu;
2101        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2102        let b = a
2103            .fast_cumsum_config(0, true, true)
2104            .unwrap()
2105            .to_vec1::<i64>()
2106            .unwrap();
2107        assert_eq!(b, [10, 9, 7, 4]);
2108    }
2109
2110    #[cfg(feature = "metal")]
2111    #[test]
2112    fn test_cumsum_exclusive_forward_metal() {
2113        use crate::utils::ops::CumSumOp;
2114        use candle_core::Tensor;
2115        let device = candle_core::Device::new_metal(0).unwrap();
2116        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2117        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
2118        assert_eq!(b, [0, 1, 3, 6]);
2119    }
2120
2121    #[cfg(feature = "metal")]
2122    #[test]
2123    fn test_cumsum_inclusive_forward_metal() {
2124        use crate::utils::ops::CumSumOp;
2125        use candle_core::Tensor;
2126        let device = candle_core::Device::new_metal(0).unwrap();
2127        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2128        let b = a
2129            .fast_cumsum_config(0, true, false)
2130            .unwrap()
2131            .to_vec1::<i64>()
2132            .unwrap();
2133        assert_eq!(b, [1, 3, 6, 10]);
2134    }
2135
2136    #[cfg(feature = "metal")]
2137    #[test]
2138    fn test_cumsum_exclusive_reverse_metal() {
2139        use crate::utils::ops::CumSumOp;
2140        use candle_core::Tensor;
2141        let device = candle_core::Device::new_metal(0).unwrap();
2142        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2143        let b = a
2144            .fast_cumsum_config(0, false, true)
2145            .unwrap()
2146            .to_vec1::<i64>()
2147            .unwrap();
2148        assert_eq!(b, [9, 7, 4, 0]);
2149    }
2150
2151    #[cfg(feature = "metal")]
2152    #[test]
2153    fn test_cumsum_inclusive_reverse_metal() {
2154        use crate::utils::ops::CumSumOp;
2155        use candle_core::Tensor;
2156        let device = candle_core::Device::new_metal(0).unwrap();
2157        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2158        let b = a
2159            .fast_cumsum_config(0, true, true)
2160            .unwrap()
2161            .to_vec1::<i64>()
2162            .unwrap();
2163        assert_eq!(b, [10, 9, 7, 4]);
2164    }
2165
2166    #[test]
2167    fn test_nonzero_cpu() {
2168        use crate::utils::ops::NonZeroOp;
2169        use candle_core::Tensor;
2170        let device = candle_core::Device::Cpu;
2171        let a = Tensor::from_vec(
2172            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
2173            &[2, 4],
2174            &device,
2175        )
2176        .unwrap();
2177        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
2178        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
2179    }
2180
2181    #[cfg(feature = "cuda")]
2182    #[test]
2183    fn test_nonzero_cuda() {
2184        use crate::utils::ops::NonZeroOp;
2185        use candle_core::Tensor;
2186        let device = candle_core::Device::new_cuda(0).unwrap();
2187        let a = Tensor::from_vec(
2188            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
2189            &[2, 4],
2190            &device,
2191        )
2192        .unwrap();
2193        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
2194        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
2195    }
2196
2197    #[test]
2198    fn test_bitwise_and_cpu() {
2199        use crate::utils::ops::BitWiseOp;
2200        use candle_core::Tensor;
2201        let device = candle_core::Device::Cpu;
2202        let a =
2203            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2204        let b =
2205            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2206        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
2207        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
2208    }
2209
2210    #[cfg(feature = "cuda")]
2211    #[test]
2212    fn test_bitwise_and_cuda() {
2213        use crate::utils::ops::BitWiseOp;
2214        use candle_core::Tensor;
2215        let device = candle_core::Device::new_cuda(0).unwrap();
2216        let a =
2217            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2218        let b =
2219            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
2220        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
2221        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
2222    }
2223
2224    #[test]
2225    fn test_bitwise_or_cpu() {
2226        use crate::utils::ops::BitWiseOp;
2227        use candle_core::Tensor;
2228        let device = candle_core::Device::Cpu;
2229        let a =
2230            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2231        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
2232        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
2233        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
2234    }
2235
2236    #[cfg(feature = "cuda")]
2237    #[test]
2238    fn test_bitwise_or_cuda() {
2239        use crate::utils::ops::BitWiseOp;
2240        use candle_core::Tensor;
2241        let device = candle_core::Device::new_cuda(0).unwrap();
2242        let a =
2243            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2244        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
2245        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
2246        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
2247    }
2248
2249    #[test]
2250    fn test_bitwise_xor_cpu() {
2251        use crate::utils::ops::BitWiseOp;
2252        use candle_core::Tensor;
2253        let device = candle_core::Device::Cpu;
2254        let a =
2255            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2256        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
2257        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
2258        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
2259    }
2260
2261    #[cfg(feature = "cuda")]
2262    #[test]
2263    fn test_bitwise_xor_cuda() {
2264        use crate::utils::ops::BitWiseOp;
2265        use candle_core::Tensor;
2266        let device = candle_core::Device::new_cuda(0).unwrap();
2267        let a =
2268            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
2269        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
2270        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
2271        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
2272    }
2273
2274    #[test]
2275    fn test_nonzero_and() {
2276        use crate::utils::ops::{BitWiseOp, NonZeroOp};
2277        use candle_core::{Device, Tensor};
2278
2279        let input1 = Tensor::from_vec(
2280            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
2281            (10,),
2282            &Device::Cpu,
2283        )
2284        .unwrap();
2285        let input2 = Tensor::from_vec(
2286            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
2287            (10,),
2288            &Device::Cpu,
2289        )
2290        .unwrap();
2291        let input = Tensor::stack(&[input1, input2], 0).unwrap();
2292
2293        let lt = input.lt(0.0).unwrap();
2294        let gt = input.gt(-10.0).unwrap();
2295        let res = lt
2296            .bitwise_and(&gt)
2297            .unwrap()
2298            .nonzero()
2299            .unwrap()
2300            .to_vec2::<u32>()
2301            .unwrap();
2302
2303        assert_eq!(
2304            res,
2305            [
2306                [0, 3],
2307                [0, 4],
2308                [0, 5],
2309                [0, 6],
2310                [1, 0],
2311                [1, 3],
2312                [1, 5],
2313                [1, 6]
2314            ]
2315        );
2316    }
2317
2318    #[cfg(feature = "cuda")]
2319    #[test]
2320    fn nonzero_and_cuda() {
2321        use crate::utils::ops::{BitWiseOp, NonZeroOp};
2322        use candle_core::{Device, Tensor};
2323
2324        let device = Device::new_cuda(0).unwrap();
2325        let input1 =
2326            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
2327        let input2 =
2328            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
2329        let input = Tensor::stack(&[input1, input2], 0).unwrap();
2330
2331        let lt = input.lt(0.0).unwrap();
2332        let gt = input.gt(-10.0).unwrap();
2333        let res = lt
2334            .bitwise_and(&gt)
2335            .unwrap()
2336            .nonzero()
2337            .unwrap()
2338            .to_vec2::<u32>()
2339            .unwrap();
2340
2341        assert_eq!(
2342            res,
2343            [
2344                [0, 3],
2345                [0, 4],
2346                [0, 5],
2347                [0, 6],
2348                [1, 0],
2349                [1, 3],
2350                [1, 5],
2351                [1, 6]
2352            ]
2353        );
2354    }
2355
2356    #[test]
2357    fn test_bitpack_8bit_cpu() {
2358        use crate::HqqBits;
2359        use candle_core::{Device, Tensor};
2360        let bits = HqqBits::Eight;
2361        let device = Device::Cpu;
2362        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
2363        let c = bits.bitpack_type()(wq.clone())
2364            .unwrap()
2365            .to_vec2::<u8>()
2366            .unwrap();
2367        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
2368    }
2369
2370    #[cfg(feature = "cuda")]
2371    #[test]
2372    fn test_bitpack_8bit_cuda() {
2373        use crate::HqqBits;
2374        use candle_core::DType;
2375        use candle_core::{Device, Tensor};
2376        let bits = HqqBits::Eight;
2377        let device = Device::new_cuda(0).unwrap();
2378        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
2379        let c = bits.bitpack_type()(wq.clone())
2380            .unwrap()
2381            .to_dtype(DType::U8)
2382            .unwrap()
2383            .to_vec2::<u8>()
2384            .unwrap();
2385        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
2386    }
2387
2388    #[cfg(feature = "metal")]
2389    #[test]
2390    fn test_bitpack_8bit_metal() {
2391        use crate::HqqBits;
2392        use candle_core::{Device, Tensor};
2393        let bits = HqqBits::Eight;
2394        let device = Device::new_metal(0).unwrap();
2395        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
2396        let c = bits.bitpack_type()(wq.clone())
2397            .unwrap()
2398            .to_vec2::<u8>()
2399            .unwrap();
2400        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
2401    }
2402
2403    #[test]
2404    fn test_bitpack_4bit() {
2405        use crate::HqqBits;
2406        use candle_core::{Device, Tensor};
2407        let bits = HqqBits::Four;
2408        let device = Device::Cpu;
2409        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
2410        let c = bits.bitpack_type()(wq.clone())
2411            .unwrap()
2412            .to_vec2::<u8>()
2413            .unwrap();
2414        assert_eq!(c, [[19, 36]]);
2415    }
2416
2417    #[cfg(feature = "cuda")]
2418    #[test]
2419    fn test_bitpack_4bit_cuda() {
2420        use crate::HqqBits;
2421        use candle_core::{Device, Tensor};
2422        let bits = HqqBits::Four;
2423        let device = Device::new_cuda(0).unwrap();
2424        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
2425        let c = bits.bitpack_type()(wq.clone())
2426            .unwrap()
2427            .to_vec2::<u8>()
2428            .unwrap();
2429        assert_eq!(c, [[19, 36]]);
2430    }
2431
2432    #[cfg(feature = "metal")]
2433    #[test]
2434    fn test_bitpack_4bit_metal() {
2435        use crate::HqqBits;
2436        use candle_core::{Device, Tensor};
2437        let bits = HqqBits::Four;
2438        let device = Device::new_metal(0).unwrap();
2439        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
2440        let c = bits.bitpack_type()(wq.clone())
2441            .unwrap()
2442            .to_vec2::<u8>()
2443            .unwrap();
2444        assert_eq!(c, [[19, 36]]);
2445    }
2446    // ─────────────────────────────── Sort / ArgSort ────────────────────────────────
2447    #[cfg(feature = "metal")]
2448    #[test]
2449    fn test_sort_and_argsort_vector_metal() {
2450        use crate::utils::ops::SortOp;
2451        use candle_core::Tensor;
2452
2453        let device = candle_core::Device::new_metal(0).unwrap();
2454        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
2455
2456        // sort (ascending)
2457        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
2458        assert_eq!(sorted, [1, 2, 3, 4]);
2459
2460        // argsort (ascending indices)
2461        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
2462        assert_eq!(idx, [1, 3, 0, 2]);
2463    }
2464
2465    #[cfg(feature = "metal")]
2466    #[test]
2467    fn test_sort_and_argsort_matrix_axis1_metal() {
2468        use crate::utils::ops::SortOp;
2469        use candle_core::Tensor;
2470
2471        let device = candle_core::Device::new_metal(0).unwrap();
2472        // 2 × 3 matrix:
2473        // [[3, 1, 2],
2474        //  [0, 4, 5]]
2475        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
2476
2477        // Sort along axis=1 (second dimension)
2478        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
2479        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
2480
2481        // ArgSort indices along axis=1
2482        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
2483        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
2484    }
2485
2486    // ─────────────────────────────── 4 096-element vector ────────────────────────────────
2487    #[cfg(feature = "metal")]
2488    #[test]
2489    fn test_sort_and_argsort_vector_4096_metal() {
2490        use crate::utils::ops::SortOp;
2491        use candle_core::Tensor;
2492
2493        const N: usize = 4096;
2494
2495        let device = candle_core::Device::new_metal(0).expect("Metal device");
2496
2497        // Create a descending vector [4095, 4094, …, 0]
2498        let vals: Vec<i32> = (0..N as i32).rev().collect();
2499        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
2500
2501        // ---- sort (ascending) ---------------------------------------------------------
2502        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
2503        let expected: Vec<i32> = (0..N as i32).collect();
2504        assert_eq!(sorted, expected);
2505
2506        // ---- argsort (indices that would sort) ---------------------------------------
2507        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
2508        // Because the input is reversed, the correct indices are likewise reversed
2509        for (i, &v) in idx.iter().enumerate() {
2510            assert_eq!(v as usize, N - 1 - i);
2511        }
2512    }
2513}