use candle_core::{shape::Dim, DType, Result, Tensor, D};

#[cfg(feature = "cuda")]
use crate::cuda::ffi;
use crate::layers::Activation;

/// `CustomOp1` that argsorts the last dimension via the CUDA `asort_*` kernels.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ArgSort {
    /// Sort ascending when true, descending otherwise.
    asc: bool,
    /// Size of the last dimension (the dimension being sorted).
    last_dim: usize,
    /// When true, the kernel also rearranges the input values in place.
    inplace: bool,
}

impl candle_core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        _: &candle_core::CpuStorage,
        _: &candle_core::Layout,
    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
        candle_core::bail!("ArgSort is only implemented for the CUDA backend")
    }

    #[allow(clippy::cast_possible_truncation)]
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;
        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
        use candle_core::cuda_backend::CudaStorageSlice;

        // The kernel sorts each row of a (nrows, ncols) view, where ncols is the
        // size of the last dimension.
        let dev = storage.device();
        let elem_count = layout.shape().elem_count();
        let ncols = self.last_dim as i32;
        let nrows = elem_count as i32 / ncols;
        let dst = unsafe { dev.alloc::<u32>(elem_count) }?;

        use std::ffi::c_void;

        let (src, _src_guard) = match &storage.slice {
            CudaStorageSlice::U8(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::U32(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::I64(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::BF16(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F16(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F32(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F64(inp) => inp.device_ptr(inp.stream()),
            _ => candle_core::bail!("Unexpected dtype in asort"),
        };
        let src_ptr = src as *const c_void;
        let (dst_ptr, dst_guard) = dst.device_ptr(dst.stream());
        let dst_ptr = dst_ptr as *mut c_void;
        let stream = dev.cuda_stream().cu_stream() as i64;
        // Dispatch to the dtype- and direction-specialized CUDA kernel.
        unsafe {
            if self.asc {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            } else {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            }
        }
        // Wrap the index buffer back up as U32 CUDA storage with the input shape.
        drop(dst_guard);
        let dst_ret = candle_core::cuda_backend::CudaStorage {
            slice: CudaStorageSlice::U32(dst),
            device: dev.clone(),
        };
        Ok((dst_ret, layout.shape().clone()))
    }
}

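/// Sort/argsort along the last dimension, backed by the custom CUDA `asort_*`
/// kernels above (there is no CPU implementation; see `cpu_fwd`).
///
/// `arg_sort` returns only the sorting indices, while `sort` returns the sorted
/// values together with the indices. Both require a contiguous tensor.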
#[allow(dead_code)]
pub trait ArgSortOp {
    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
}

impl ArgSortOp for Tensor {
    fn arg_sort(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        self.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: false,
        })
    }

    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in sort"),
        };
        // Sort a copy in place: the kernel rearranges `sorted` and returns the
        // corresponding indices.
        let sorted = self.copy()?;

        let asort = sorted.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: true,
        })?;

        Ok((sorted, asort))
    }
}

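/// Values and indices returned by [`TopKLastDimOp::topk`] and
/// [`TopKLastDimOp::topk_unsorted`].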
#[allow(dead_code)]
pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}

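/// Top-k along the last dimension.
///
/// `topk` returns the `k` largest values (and their indices) sorted by
/// descending value; `topk_unsorted` returns the same entries re-ordered so
/// that the indices are ascending. A minimal usage sketch, mirroring
/// `test_topk` below:
///
/// ```ignore
/// // x == [[1., 3., 5.], [2., 4., 6.]]
/// let TopKOutput { values, indices } = x.topk(2)?;
/// // values  == [[5., 3.], [6., 4.]]
/// // indices == [[2, 1], [2, 1]]
/// ```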
pub trait TopKLastDimOp {
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}

impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Sort the last dim in descending order, using the custom CUDA kernel
        // when available and candle's built-in sort otherwise.
        let (values, sorted_indices) = if self.device().is_cuda() {
            self.sort(false)?
        } else {
            self.sort_last_dim(false)?
        };
        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
        Ok(TopKOutput {
            values: topk_values,
            indices: topk_indices,
        })
    }

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
        let TopKOutput { values, indices } = self.topk(topk)?;
        // Re-order the top-k entries so that the indices are ascending.
        #[cfg(feature = "cuda")]
        let reorder_indices = indices.arg_sort(true)?;
        #[cfg(not(feature = "cuda"))]
        let reorder_indices = indices.arg_sort_last_dim(true)?;
        // The indices are gathered through an f32 round-trip and converted back
        // to u32 afterwards.
        let topk_indices_unsorted = indices
            .to_dtype(DType::F32)?
            .gather(&reorder_indices, D::Minus1)?
            .to_dtype(DType::U32)?;
        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
        Ok(TopKOutput {
            values: topk_values_unsorted,
            indices: topk_indices_unsorted,
        })
    }
}

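/// Repeat tensor elements, in the spirit of `torch.repeat_interleave`.
///
/// `repeat_interleave` repeats every element along `dim` a fixed number of
/// times; `repeat_interleave_flat` flattens the tensor and repeats element `i`
/// exactly `repeats[i]` times. A minimal sketch, mirroring the tests below:
///
/// ```ignore
/// // shape (1, 2, 3) -> shape (1, 2, 6), each entry doubled along the last dim
/// let y = x.repeat_interleave(2, 2)?;
/// ```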
pub trait RepeatInterleaveOp {
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor>;
    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
}

impl RepeatInterleaveOp for Tensor {
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor> {
        let dim = dim.to_index(self.shape(), "repeat_interleave")?;
        let dim_elements = self.dim(dim)?;
        // Only float tensors are supported here.
        assert!(self.dtype().is_float());
        // Build the index vector [0, 0, ..., 1, 1, ...] and gather along `dim`.
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..dim_elements)
                .flat_map(|i| vec![i as u32; repeats])
                .collect::<Vec<_>>(),
            self.device(),
        )?;
        self.index_select(&indices, dim)
    }

    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
        let xs = self.flatten_all()?;
        if repeats.len() != xs.dim(0)? {
            candle_core::bail!(
                "repeats ({}) must match flattened self length ({})",
                repeats.len(),
                xs.dim(0)?
            );
        }
        // Element `i` is repeated `repeats[i]` times.
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..xs.dim(0)?)
                .flat_map(|i| vec![i as u32; repeats[i] as usize])
                .collect::<Vec<_>>(),
            xs.device(),
        )?;
        xs.index_select(&indices, 0)
    }
}

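/// Split a tensor into chunks of the given sizes along `dim`, similar to
/// `torch.split` with explicit split sizes. A minimal sketch, assuming a 1-D
/// tensor of length 6:
///
/// ```ignore
/// let parts = x.split(&[2, 4], 0)?; // chunks of length 2 and 4
/// ```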
pub trait SplitOp {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}

impl SplitOp for Tensor {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
        let dim = dim.to_index(self.shape(), "split")?;
        let mut split_res = Vec::new();
        let mut index = 0;
        // Take successive narrow() views, advancing the offset by each split size.
        for split in splits {
            split_res.push(self.narrow(dim, index, *split)?);
            index += *split;
        }
        Ok(split_res)
    }
}

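/// Count occurrences of each value in a `u32` tensor, similar to
/// `torch.bincount`: the result has at least `minlength` entries.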
#[allow(dead_code)]
pub trait BincountOp {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}

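/// Parallel histogram over `values`: each rayon worker folds a local histogram
/// of length `max(max_val + 1, minlength)`, and the local histograms are then
/// summed element-wise in the reduce step.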
#[allow(dead_code)]
fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    if values.is_empty() {
        return vec![0u32; minlength as usize];
    }

    let max_val = *values
        .par_iter()
        .max()
        .expect("values should be non-empty after empty check");

    let result_len = (max_val + 1).max(minlength) as usize;

    values
        .par_iter()
        .fold(
            || vec![0u32; result_len],
            |mut local_hist, &v| {
                // SAFETY: v <= max_val < result_len, so the index is in bounds.
                unsafe {
                    *local_hist.get_unchecked_mut(v as usize) += 1;
                }
                local_hist
            },
        )
        .reduce(
            || vec![0u32; result_len],
            |mut global_hist, local_hist| {
                for i in 0..result_len {
                    // SAFETY: both histograms have length result_len.
                    unsafe {
                        *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
                    }
                }
                global_hist
            },
        )
}

#[allow(dead_code)]
impl BincountOp for Tensor {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
        let values = self.to_vec1::<u32>()?;

        Ok(bincount(&values, minlength))
    }
}

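/// Keep one triangle of a 2-D tensor and zero out the other, in the spirit of
/// `torch.triu`/`torch.tril`: with `upper = true` entries with
/// `j >= i + diagonal` are kept, with `upper = false` entries with
/// `j <= i + diagonal` are kept.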
pub fn apply_triangular(xs: &Tensor, diagonal: isize, upper: bool) -> Result<Tensor> {
    let device = xs.device();
    let (l, s) = xs.dims2()?;
    // Build a 0/1 mask on the host and multiply it into the input.
    let mut xs_tri = vec![];
    for i in 0..l as isize {
        for j in 0..s as isize {
            let cond = if upper {
                i + diagonal > j
            } else {
                i + diagonal < j
            };
            xs_tri.push(if cond { 0u8 } else { 1u8 });
        }
    }
    xs * Tensor::from_vec(xs_tri, (l, s), device)?.to_dtype(xs.dtype())?
}

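/// Apply `act` to `a` and multiply the result element-wise with `b`,
/// i.e. `act(a) * b`.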
pub fn mul_and_act(a: &Tensor, b: &Tensor, act: Activation) -> Result<Tensor> {
    a.apply(&act)? * b
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        // x == [[1, 3, 5], [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }
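
    // Additional minimal checks, added as a sketch in the same style as the
    // tests above: they exercise SplitOp and BincountOp on the CPU device.
    #[test]
    fn test_split() -> candle_core::Result<()> {
        use crate::ops::SplitOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
        let chunks = input.split(&[2, 4], 0)?;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].to_vec1::<f32>()?, vec![0., 1.]);
        assert_eq!(chunks[1].to_vec1::<f32>()?, vec![2., 3., 4., 5.]);

        Ok(())
    }

    #[test]
    fn test_bincount() -> candle_core::Result<()> {
        use crate::ops::BincountOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![3u32, 1, 1, 0], &Device::Cpu)?;
        // Counts for the values 0..=3, padded with zeros up to `minlength`.
        assert_eq!(input.bincount(0)?, vec![1u32, 2, 0, 1]);
        assert_eq!(input.bincount(6)?, vec![1u32, 2, 0, 1, 0, 0]);

        Ok(())
    }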
}