mistralrs_core/ops.rs
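//! Custom tensor operations on `candle` tensors: a CUDA-backed argsort/sort, top-k helpers,
//! `repeat_interleave`, `split`, and a parallel `bincount`.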

use candle_core::{shape::Dim, DType, Result, Tensor, D};

#[cfg(feature = "cuda")]
use crate::cuda::ffi;

#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
    inplace: bool,
}

impl candle_core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        _: &candle_core::CpuStorage,
        _: &candle_core::Layout,
    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
        panic!("not implemented!")
    }

    #[allow(clippy::cast_possible_truncation)]
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;
        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
        use candle_core::cuda_backend::CudaStorageSlice;
        use candle_core::cuda_backend::WrapErr;
        let dev = storage.device();
        let elem_count = layout.shape().elem_count();
        let ncols = self.last_dim as i32;
        let nrows = elem_count as i32 / ncols;
        // Allocate a device buffer holding one u32 index per input element.
        let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;

        use std::ffi::c_void;

        // Obtain a raw device pointer to the input, whatever its dtype.
        let src = match &storage.slice {
            CudaStorageSlice::U8(inp) => inp.device_ptr(),
            CudaStorageSlice::U32(inp) => inp.device_ptr(),
            CudaStorageSlice::I64(inp) => inp.device_ptr(),
            CudaStorageSlice::BF16(inp) => inp.device_ptr(),
            CudaStorageSlice::F16(inp) => inp.device_ptr(),
            CudaStorageSlice::F32(inp) => inp.device_ptr(),
            CudaStorageSlice::F64(inp) => inp.device_ptr(),
            _ => candle_core::bail!("Unexpected dtype in asort"),
        };
        let src_ptr = *src as *const c_void;
        let dst_ptr = *dst.device_ptr() as *mut c_void;
        let stream = *dev.cu_stream() as i64;
        // Dispatch to the dtype- and direction-specific CUDA kernel.
        unsafe {
            if self.asc {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            } else {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            }
        }
        // The sorted indices always come back as u32, regardless of the input dtype.
        let dst_ret = candle_core::cuda_backend::CudaStorage {
            slice: CudaStorageSlice::U32(dst),
            device: dev.clone(),
        };
        Ok((dst_ret, layout.shape().clone()))
    }
}

#[allow(dead_code)]
pub trait ArgSortOp {
    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
}

impl ArgSortOp for Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise, sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order of
    /// ties.
    fn arg_sort(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: false,
        })
    }

    /// Sorts the tensor along the last dimension, returning the sorted tensor together with the
    /// sorted indices.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise, sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order of
    /// ties.
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        // Sort a copy so that `self` is left untouched; the kernel sorts `sorted` in place and
        // returns the indices.
        let sorted = self.copy()?;

        let asort = sorted.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: true,
        })?;

        Ok((sorted, asort))
    }
}

#[allow(dead_code)]
pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}

pub trait TopKLastDimOp {
    /// Top-k in the last dimension. `values` retains a gradient, but `indices` has none w.r.t. `self`.
    /// This expects a contiguous tensor.
    /// Note: this implements `torch.topk` with `sorted=True`.
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    /// Top-k in the last dimension. `values` retains a gradient, but `indices` has none w.r.t. `self`.
    /// This expects a contiguous tensor.
    /// Note: this implements `torch.topk` with `sorted=False`.
    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}

impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Sorted descending
        // #[cfg(feature = "cuda")]
        // let (values, sorted_indices) = self.sort(false)?;
        // #[cfg(not(feature = "cuda"))]
        let (values, sorted_indices) = self.sort_last_dim(false)?;
        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
        Ok(TopKOutput {
            values: topk_values,
            indices: topk_indices,
        })
    }

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
        // Sorted descending
        let TopKOutput { values, indices } = self.topk(topk)?;
        // Reorder the indices ascending
        #[cfg(feature = "cuda")]
        let reorder_indices = indices.arg_sort(true)?;
        #[cfg(not(feature = "cuda"))]
        let reorder_indices = indices.arg_sort_last_dim(true)?;
        let topk_indices_unsorted = indices
            .to_dtype(DType::F32)?
            .gather(&reorder_indices, D::Minus1)?
            .to_dtype(DType::U32)?;
        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
        Ok(TopKOutput {
            values: topk_values_unsorted,
            indices: topk_indices_unsorted,
        })
    }
}

pub trait RepeatInterleaveOp {
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor>;
    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
}

impl RepeatInterleaveOp for Tensor {
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor> {
        let dim = dim.to_index(self.shape(), "repeat_interleave")?;
        let dim_elements = self.dim(dim)?;
        // For Metal compatibility: this path only supports float dtypes.
        assert!(self.dtype().is_float());
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..dim_elements)
                .flat_map(|i| vec![i as u32; repeats])
                .collect::<Vec<_>>(),
            self.device(),
        )?;
        self.index_select(&indices, dim)
    }

    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
        let xs = self.flatten_all()?;
        if repeats.len() != xs.dim(0)? {
            candle_core::bail!(
                "repeats ({}) must match flattened self length ({})",
                repeats.len(),
                xs.dim(0)?
            );
        }
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..xs.dim(0)?)
                .flat_map(|i| vec![i as u32; repeats[i] as usize])
                .collect::<Vec<_>>(),
            xs.device(),
        )?;
        xs.index_select(&indices, 0)
    }
}

pub trait SplitOp {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}

impl SplitOp for Tensor {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
        let dim = dim.to_index(self.shape(), "split")?;
        let mut split_res = Vec::new();
        let mut index = 0;
        for split in splits {
            split_res.push(self.narrow(dim, index, *split)?);
            index += *split;
        }
        Ok(split_res)
    }
}

#[allow(dead_code)]
pub trait BincountOp {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}

#[allow(dead_code)]
fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
    // let max_val = values.iter().max().copied().unwrap_or(0);
    // let result_len = (max_val + 1).max(minlength);
    // values.iter().fold(
    //     // Start with a histogram vector of zeros.
    //     vec![0u32; result_len as usize],
    //     // For each value, update the histogram.
    //     |mut histogram, &value| {
    //         histogram[value as usize] += 1;
    //         histogram
    //     },
    // )

    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    // Early return if there are no values.
    if values.is_empty() {
        return vec![0u32; minlength as usize];
    }

    // Compute the maximum value in parallel.
    // SAFETY: we know `values` is nonempty.
    let max_val = *values.par_iter().max().unwrap();

    // The histogram length must cover all observed values as well as `minlength`.
    let result_len = (max_val + 1).max(minlength) as usize;

    // Build per-thread histograms in parallel.
    // We use unsafe indexing to eliminate bounds checks in the inner loop.
    values
        .par_iter()
        .fold(
            || vec![0u32; result_len],
            |mut local_hist, &v| {
                // SAFETY: v is guaranteed to be <= max_val, so it is in bounds.
                unsafe {
                    *local_hist.get_unchecked_mut(v as usize) += 1;
                }
                local_hist
            },
        )
        // Merge the per-thread histograms in parallel.
        .reduce(
            || vec![0u32; result_len],
            |mut global_hist, local_hist| {
                for i in 0..result_len {
                    // SAFETY: both histograms have exactly `result_len` elements, so `i` is in bounds.
                    unsafe {
                        *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
                    }
                }
                global_hist
            },
        )
}

#[allow(dead_code)]
impl BincountOp for Tensor {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
        let values = self.to_vec1::<u32>()?;

        Ok(bincount(&values, minlength))
    }
}

mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        //  [[1, 3, 5],
        //   [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }
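
    // Editorial addition (not in the original source): a sketch of a companion test for
    // `topk_unsorted`, following the same setup as `test_topk` and assuming the CPU
    // `sort_last_dim`/`arg_sort_last_dim` path above. With k = 2, the top-k entries should come
    // back in their original (index-ascending) order. Gated off the CUDA feature because, with
    // "cuda" enabled, `topk_unsorted` routes through the CUDA-only `arg_sort` custom op, which has
    // no CPU fallback.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_topk_unsorted() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        //  [[1, 3, 5],
        //   [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk_unsorted(2).unwrap();
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![3f32, 5f32], vec![4f32, 6f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![1u32, 2u32], vec![1u32, 2u32]]
        );
    }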

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }
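
    // Editorial additions (not in the original source): small sketches exercising `SplitOp::split`
    // and `BincountOp::bincount` on the CPU backend, following the patterns of the tests above.
    #[test]
    fn test_split() -> candle_core::Result<()> {
        use crate::ops::SplitOp;
        use candle_core::{Device, Tensor};

        // Split a length-6 vector into chunks of sizes 2, 3 and 1 along dim 0.
        let input = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
        let chunks = input.split(&[2, 3, 1], 0)?;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].to_vec1::<f32>()?, vec![0., 1.]);
        assert_eq!(chunks[1].to_vec1::<f32>()?, vec![2., 3., 4.]);
        assert_eq!(chunks[2].to_vec1::<f32>()?, vec![5.]);

        Ok(())
    }

    #[test]
    fn test_bincount() -> candle_core::Result<()> {
        use crate::ops::BincountOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![3u32, 1, 4, 1, 5], &Device::Cpu)?;
        // With no minlength padding, the histogram runs up to the max value (5).
        assert_eq!(input.bincount(0)?, vec![0, 2, 0, 1, 1, 1]);
        // A larger minlength pads the histogram with trailing zeros.
        assert_eq!(input.bincount(8)?, vec![0, 2, 0, 1, 1, 1, 0, 0]);

        Ok(())
    }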
}