use candle_core::{shape::Dim, DType, Result, Tensor, D};

#[cfg(feature = "cuda")]
use crate::cuda::ffi;
use crate::layers::Activation;

#[allow(dead_code)]
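/// Argsort along the last dimension, implemented as a `CustomOp1` backed by custom CUDA
/// kernels. The CPU forward pass is intentionally unimplemented.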
#[derive(Debug, Clone)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
    inplace: bool,
}

impl candle_core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    // The CPU forward pass is intentionally unimplemented; this op is only expected to
    // run on CUDA devices.
    fn cpu_fwd(
        &self,
        _: &candle_core::CpuStorage,
        _: &candle_core::Layout,
    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
        panic!("not implemented!")
    }

    #[allow(clippy::cast_possible_truncation)]
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;
        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
        use candle_core::cuda_backend::CudaStorageSlice;

        let dev = storage.device();
        let elem_count = layout.shape().elem_count();
        // The kernel sorts each row of an (nrows, ncols) view, where ncols is the last dim.
        let ncols = self.last_dim as i32;
        let nrows = elem_count as i32 / ncols;
        // Output buffer holding the `u32` sort indices.
        let dst = unsafe { dev.alloc::<u32>(elem_count) }?;

        use std::ffi::c_void;

        let (src, _src_guard) = match &storage.slice {
            CudaStorageSlice::U8(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::U32(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::I64(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::BF16(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F16(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F32(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F64(inp) => inp.device_ptr(inp.stream()),
            _ => candle_core::bail!("Unexpected dtype in asort"),
        };
        let src_ptr = src as *const c_void;
        let (dst_ptr, dst_guard) = dst.device_ptr(dst.stream());
        let dst_ptr = dst_ptr as *mut c_void;
        let stream = dev.cuda_stream().cu_stream() as i64;
        unsafe {
            if self.asc {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            } else {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            }
        }
        drop(dst_guard);
        let dst_ret = candle_core::cuda_backend::CudaStorage {
            slice: CudaStorageSlice::U32(dst),
            device: dev.clone(),
        };
        Ok((dst_ret, layout.shape().clone()))
    }
}

#[allow(dead_code)]
pub trait ArgSortOp {
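    /// Returns the `u32` indices that sort the tensor along its last dimension,
    /// ascending if `asc` is `true`. The tensor must be contiguous.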
    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
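
    /// Sorts the tensor along its last dimension (ascending if `asc` is `true`) and
    /// returns `(sorted_values, sort_indices)`. The tensor must be contiguous.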
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
}

impl ArgSortOp for Tensor {
    fn arg_sort(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        self.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: false,
        })
    }

    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        // Sort a copy of `self` in place on the device; the op returns the indices.
        let sorted = self.copy()?;

        let asort = sorted.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: true,
        })?;

        Ok((sorted, asort))
    }
}

#[allow(dead_code)]
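/// Values and `u32` indices returned by `TopKLastDimOp::topk`.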
pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}

pub trait TopKLastDimOp {
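    /// Top-k over the last dimension: returns the `topk` largest values and their
    /// `u32` indices, ordered by descending value.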
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    /// Like `topk`, but the returned values and indices are re-ordered by ascending
    /// index rather than by descending value.
    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}

impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Sort in descending order along the last dim and keep the first `topk` entries.
        let (values, sorted_indices) = self.sort_last_dim(false)?;
        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
        Ok(TopKOutput {
            values: topk_values,
            indices: topk_indices,
        })
    }

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
        let TopKOutput { values, indices } = self.topk(topk)?;
        // Re-order the top-k entries by ascending index.
        #[cfg(feature = "cuda")]
        let reorder_indices = indices.arg_sort(true)?;
        #[cfg(not(feature = "cuda"))]
        let reorder_indices = indices.arg_sort_last_dim(true)?;
        let topk_indices_unsorted = indices
            .to_dtype(DType::F32)?
            .gather(&reorder_indices, D::Minus1)?
            .to_dtype(DType::U32)?;
        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
        Ok(TopKOutput {
            values: topk_values_unsorted,
            indices: topk_indices_unsorted,
        })
    }
}

pub trait RepeatInterleaveOp {
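    /// Repeats each element `repeats` times along dimension `dim`.
    /// Only floating-point tensors are supported (asserted).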
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor>;
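
    /// Flattens the tensor and repeats element `i` `repeats[i]` times; `repeats` must
    /// contain one entry per element of the flattened tensor.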
    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
}

impl RepeatInterleaveOp for Tensor {
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor> {
        let dim = dim.to_index(self.shape(), "repeat_interleave")?;
        let dim_elements = self.dim(dim)?;
        assert!(self.dtype().is_float());
        // Gather indices: each position `i` along `dim`, repeated `repeats` times.
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..dim_elements)
                .flat_map(|i| vec![i as u32; repeats])
                .collect::<Vec<_>>(),
            self.device(),
        )?;
        self.index_select(&indices, dim)
    }

    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
        let xs = self.flatten_all()?;
        if repeats.len() != xs.dim(0)? {
            candle_core::bail!(
                "repeats ({}) must match flattened self length ({})",
                repeats.len(),
                xs.dim(0)?
            );
        }
        // Gather indices: element `i`, repeated `repeats[i]` times.
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..xs.dim(0)?)
                .flat_map(|i| vec![i as u32; repeats[i] as usize])
                .collect::<Vec<_>>(),
            xs.device(),
        )?;
        xs.index_select(&indices, 0)
    }
}

pub trait SplitOp {
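    /// Splits the tensor along `dim` into consecutive chunks with the given sizes
    /// (the sizes must not add up to more than the dimension's length).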
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}

impl SplitOp for Tensor {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
        let dim = dim.to_index(self.shape(), "split")?;
        let mut split_res = Vec::new();
        let mut index = 0;
        for split in splits {
            split_res.push(self.narrow(dim, index, *split)?);
            index += *split;
        }
        Ok(split_res)
    }
}

#[allow(dead_code)]
pub trait BincountOp {
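    /// Counts the occurrences of each value in a `u32` tensor, returning a histogram of
    /// length `max(max_value + 1, minlength)`.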
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}

#[allow(dead_code)]
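/// Parallel histogram over `values` using rayon: each worker folds counts into a
/// thread-local histogram, and the per-thread histograms are summed in the reduce step.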
fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    if values.is_empty() {
        return vec![0u32; minlength as usize];
    }

    let max_val = *values.par_iter().max().unwrap();

    // The histogram covers 0..=max_val, padded up to `minlength` if that is larger.
    let result_len = (max_val + 1).max(minlength) as usize;

    // Fold into thread-local histograms, then sum them in the reduce step.
    values
        .par_iter()
        .fold(
            || vec![0u32; result_len],
            |mut local_hist, &v| {
                // SAFETY: `v <= max_val < result_len`, so the index is in bounds.
                unsafe {
                    *local_hist.get_unchecked_mut(v as usize) += 1;
                }
                local_hist
            },
        )
        .reduce(
            || vec![0u32; result_len],
            |mut global_hist, local_hist| {
                for i in 0..result_len {
                    // SAFETY: both histograms have length `result_len`.
                    unsafe {
                        *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
                    }
                }
                global_hist
            },
        )
}

#[allow(dead_code)]
impl BincountOp for Tensor {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
        let values = self.to_vec1::<u32>()?;

        Ok(bincount(&values, minlength))
    }
}

/// Masks a 2-D tensor to its upper (`upper == true`) or lower triangle relative to the
/// `diagonal` offset, zeroing out the remaining elements.
pub fn apply_triangular(xs: &Tensor, diagonal: isize, upper: bool) -> Result<Tensor> {
    let device = xs.device();
    let (l, s) = xs.dims2()?;
    let mut xs_tri = vec![];
    for i in 0..l as isize {
        for j in 0..s as isize {
            let cond = if upper {
                i + diagonal > j
            } else {
                i + diagonal < j
            };
            xs_tri.push(if cond { 0u8 } else { 1u8 });
        }
    }
    xs * Tensor::from_vec(xs_tri, (l, s), device)?.to_dtype(xs.dtype())?
}

/// Applies `act` to `a` and multiplies the result elementwise with `b`.
pub fn mul_and_act(a: &Tensor, b: &Tensor, act: Activation) -> Result<Tensor> {
    a.apply(&act)? * b
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        // x is the 2x3 matrix [[1, 3, 5], [2, 4, 6]].
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }
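
    // Illustrative sanity check for `SplitOp::split`, sketched from the implementation
    // above: consecutive chunks are narrow views of the requested sizes along the dim.
    #[test]
    fn test_split() -> candle_core::Result<()> {
        use crate::ops::SplitOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1f32, 2., 3., 4., 5., 6.], &Device::Cpu)?;
        let chunks = input.split(&[1, 2, 3], 0)?;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].to_vec1::<f32>()?, vec![1f32]);
        assert_eq!(chunks[1].to_vec1::<f32>()?, vec![2f32, 3.]);
        assert_eq!(chunks[2].to_vec1::<f32>()?, vec![4f32, 5., 6.]);

        Ok(())
    }

    // Illustrative sanity check for `BincountOp::bincount`, sketched from the
    // implementation above: the histogram length is `max(max_value + 1, minlength)`.
    #[test]
    fn test_bincount() -> candle_core::Result<()> {
        use crate::ops::BincountOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![0u32, 1, 1, 3, 3, 3], &Device::Cpu)?;
        assert_eq!(input.bincount(0)?, vec![1u32, 2, 0, 3]);
        assert_eq!(input.bincount(6)?, vec![1u32, 2, 0, 3, 0, 0]);

        Ok(())
    }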
}