// diffusion_rs_common/core/sort.rs
use crate::core::{Result, Tensor};
use rayon::prelude::*;
/// Custom op computing, for every row along the last dimension, the permutation of indices
/// that sorts that row.
#[derive(Debug, Clone, Copy)]
struct ArgSort {
    /// `true` for ascending order, `false` for descending order.
    asc: bool,
    /// Size of the last dimension, i.e. the length of each independently sorted row.
    last_dim: usize,
}

/// Total order used by the CPU arg-sort: the natural `PartialOrd` order, with unordered values
/// (values that do not compare with themselves, e.g. floating-point NaN) treated as larger than
/// every other value and equal to each other.
///
/// `slice::sort_by` requires a comparator that implements a total order (recent std versions
/// may panic otherwise); `partial_cmp(..).unwrap_or(Greater)` is not one, since it reports
/// both `NaN > x` and `x > NaN`.
fn cmp_unordered_last<T: PartialOrd>(a: &T, b: &T) -> std::cmp::Ordering {
    match a.partial_cmp(b) {
        Some(ordering) => ordering,
        None => {
            let a_unordered = a.partial_cmp(a).is_none();
            let b_unordered = b.partial_cmp(b).is_none();
            // `true > false`, so unordered values sort after every ordered one.
            a_unordered.cmp(&b_unordered)
        }
    }
}

impl ArgSort {
    /// Returns, for each row of `last_dim` elements, the indices that sort that row in the
    /// requested order. The result has `layout.shape().elem_count()` entries, row after row.
    ///
    /// Only the elements addressed by `layout` are read: `vs` is the whole storage buffer and
    /// the tensor may start at a non-zero offset (e.g. after a narrow along the first dim).
    /// Ties keep their original relative order (stable sort).
    fn asort<T: crate::core::WithDType>(&self, vs: &[T], layout: &crate::core::Layout) -> Vec<u32> {
        let el_count = layout.shape().elem_count();
        // `par_chunks_exact*` panics on a zero chunk size; an empty tensor has nothing to sort.
        if el_count == 0 || self.last_dim == 0 {
            return Vec::new();
        }
        let vs = match layout.contiguous_offsets() {
            Some((start, end)) => &vs[start..end],
            // `arg_sort_last_dim` only accepts contiguous tensors, so this arm is not expected
            // to be hit; keep the historical behaviour of reading from the storage start.
            None => vs,
        };
        // Zero-initialised rather than `set_len` on uninitialised memory: handing out
        // `&mut [u32]` over uninitialised integers is undefined behaviour.
        let mut sort_indexes = vec![0u32; el_count];
        let asc = self.asc;
        sort_indexes
            .par_chunks_exact_mut(self.last_dim)
            .zip(vs.par_chunks_exact(self.last_dim))
            .for_each(|(indexes, row)| {
                for (i, index) in indexes.iter_mut().enumerate() {
                    *index = i as u32;
                }
                if asc {
                    indexes.sort_by(|&i, &j| {
                        cmp_unordered_last(&row[i as usize], &row[j as usize])
                    });
                } else {
                    // Swapped operands give the reverse order; NaN then comes first.
                    indexes.sort_by(|&i, &j| {
                        cmp_unordered_last(&row[j as usize], &row[i as usize])
                    });
                }
            });
        sort_indexes
    }
}
impl crate::core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }
    fn cpu_fwd(
        &self,
        storage: &crate::core::CpuStorage,
        layout: &crate::core::Layout,
    ) -> Result<(crate::core::CpuStorage, crate::core::Shape)> {
        let sort_indexes = match storage {
            crate::core::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I8(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I16(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I32(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F64(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::core::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::core::CudaStorage,
        layout: &crate::core::Layout,
    ) -> Result<(crate::core::CudaStorage, crate::core::Shape)> {
        use crate::core::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use crate::core::cuda_backend::{
            kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr,
        };
        use crate::core::{CudaDevice, WithDType};
        #[allow(non_local_definitions)]
        impl Map1Any for ArgSort {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &crate::core::Layout,
                _wrap: W,
            ) -> Result<S> {
                let slice = match layout.contiguous_offsets() {
                    None => crate::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let elem_count = layout.shape().elem_count();
                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let func = if self.asc {
                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
                } else {
                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
                };
                let ncols = self.last_dim;
                let nrows = elem_count / ncols;
                let ncols_pad = next_power_of_2(ncols);
                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
                let cfg = LaunchConfig {
                    grid_dim: (1, nrows as u32, 1),
                    block_dim: (ncols_pad as u32, 1, 1),
                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
                };
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(S::U32(dst))
            }
        }
        use crate::core::backend::BackendStorage;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::core::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::core::MetalStorage,
        layout: &crate::core::Layout,
    ) -> Result<(crate::core::MetalStorage, crate::core::Shape)> {
        use crate::core::backend::BackendStorage;
        use crate::core::DType;
        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::I8 => "asort_asc_i8",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                    DType::I32 => "asort_asc_i32",
                    DType::I16 => "asort_asc_i16",
                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::I8 => "asort_desc_i8",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                    DType::I32 => "asort_desc_i32",
                    DType::I16 => "asort_desc_i16",
                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::core::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        crate::metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::core::Error::wrap)?;
        let dst = crate::core::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}
/// Returns the smallest power of two that is `>= x`; `next_power_of_2(0)` is `1`.
///
/// Used by the GPU backends to pad the row length handed to the arg-sort kernels.
// Only referenced from feature-gated GPU code, hence unused in CPU-only builds.
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    x.next_power_of_two()
}
impl Tensor {
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::core::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::core::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::core::D::Minus1)?;
        Ok((sorted, asort))
    }
}