mistralrs_core/
ops.rs

1use candle_core::{shape::Dim, DType, Result, Tensor, D};
2
3#[cfg(feature = "cuda")]
4use crate::cuda::ffi;
5use crate::layers::Activation;
6
/// Parameters of the custom argsort op, which sorts each row of the last dimension.
// The fields are only read by the `cuda`-feature `cuda_fwd`, so non-CUDA builds would
// otherwise warn about them being unused.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ArgSort {
    /// `true` sorts ascending, `false` sorts descending.
    asc: bool,
    /// Length of the last dimension, i.e. of every independently sorted row.
    last_dim: usize,
    /// When `true`, the kernel is expected to reorder the input buffer as well as produce the
    /// indices (`ArgSortOp::sort` relies on this).
    inplace: bool,
}
14
15impl candle_core::CustomOp1 for ArgSort {
16    fn name(&self) -> &'static str {
17        "argsort"
18    }
19
20    fn cpu_fwd(
21        &self,
22        _: &candle_core::CpuStorage,
23        _: &candle_core::Layout,
24    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
25        panic!("not implemented!")
26    }
27
28    #[allow(clippy::cast_possible_truncation)]
29    #[cfg(feature = "cuda")]
30    fn cuda_fwd(
31        &self,
32        storage: &candle_core::CudaStorage,
33        layout: &candle_core::Layout,
34    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
35        use candle_core::backend::BackendStorage;
36        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
37        use candle_core::cuda_backend::CudaStorageSlice;
38
39        let dev = storage.device();
40        let elem_count = layout.shape().elem_count();
41        let ncols = self.last_dim as i32;
42        let nrows = elem_count as i32 / ncols;
43        let dst = unsafe { dev.alloc::<u32>(elem_count) }?;
44
45        use std::ffi::c_void;
46
47        let (src, _src_guard) = match &storage.slice {
48            CudaStorageSlice::U8(inp) => inp.device_ptr(inp.stream()),
49            CudaStorageSlice::U32(inp) => inp.device_ptr(inp.stream()),
50            CudaStorageSlice::I64(inp) => inp.device_ptr(inp.stream()),
51            CudaStorageSlice::BF16(inp) => inp.device_ptr(inp.stream()),
52            CudaStorageSlice::F16(inp) => inp.device_ptr(inp.stream()),
53            CudaStorageSlice::F32(inp) => inp.device_ptr(inp.stream()),
54            CudaStorageSlice::F64(inp) => inp.device_ptr(inp.stream()),
55            _ => candle_core::bail!("Unexpected dtype in asort"),
56        };
57        let src_ptr = src as *const c_void;
58        let (dst_ptr, dst_guard) = dst.device_ptr(dst.stream());
59        let dst_ptr = dst_ptr as *mut c_void;
60        let stream = dev.cuda_stream().cu_stream() as i64;
61        unsafe {
62            if self.asc {
63                match storage.dtype() {
64                    candle_core::DType::U8 => {
65                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
66                    }
67                    candle_core::DType::U32 => {
68                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
69                    }
70                    candle_core::DType::I64 => {
71                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
72                    }
73                    candle_core::DType::BF16 => {
74                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
75                    }
76                    candle_core::DType::F16 => {
77                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
78                    }
79                    candle_core::DType::F32 => {
80                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
81                    }
82                    candle_core::DType::F64 => {
83                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
84                    }
85                    _ => candle_core::bail!("Unexpected dtype in asort"),
86                }
87            } else {
88                match storage.dtype() {
89                    candle_core::DType::U8 => {
90                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
91                    }
92                    candle_core::DType::U32 => {
93                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
94                    }
95                    candle_core::DType::I64 => {
96                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
97                    }
98                    candle_core::DType::BF16 => {
99                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
100                    }
101                    candle_core::DType::F16 => {
102                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
103                    }
104                    candle_core::DType::F32 => {
105                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
106                    }
107                    candle_core::DType::F64 => {
108                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
109                    }
110                    _ => candle_core::bail!("Unexpected dtype in asort"),
111                }
112            }
113        }
114        drop(dst_guard);
115        let dst_ret = candle_core::cuda_backend::CudaStorage {
116            slice: CudaStorageSlice::U32(dst),
117            device: dev.clone(),
118        };
119        Ok((dst_ret, layout.shape().clone()))
120    }
121}
122
123#[allow(dead_code)]
124pub trait ArgSortOp {
125    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
126    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
127}
128
129impl ArgSortOp for Tensor {
130    /// Returns the indices that sort the tensor along the last dimension.
131    ///
132    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
133    /// descending order. The sort is unstable so there is no guarantees on the final order when it
134    /// comes to ties.
135    fn arg_sort(&self, asc: bool) -> Result<Tensor> {
136        if !self.is_contiguous() {
137            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
138        }
139        let last_dim = match self.dims().last() {
140            Some(last_dim) => *last_dim,
141            None => candle_core::bail!("empty last-dim in arg-sort"),
142        };
143        // No need for a backward pass for arg sort.
144        self.apply_op1_no_bwd(&ArgSort {
145            asc,
146            last_dim,
147            inplace: false,
148        })
149    }
150
151    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
152    /// sorted indexes.
153    ///
154    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
155    /// descending order. The sort is unstable so there is no guarantees on the final order when it
156    /// comes to ties.
157    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
158        if !self.is_contiguous() {
159            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
160        }
161        let last_dim = match self.dims().last() {
162            Some(last_dim) => *last_dim,
163            None => candle_core::bail!("empty last-dim in arg-sort"),
164        };
165        let sorted = self.copy()?;
166
167        let asort = sorted.apply_op1_no_bwd(&ArgSort {
168            asc,
169            last_dim,
170            inplace: true,
171        })?;
172
173        Ok((sorted, asort))
174    }
175}
176
177#[allow(dead_code)]
178pub struct TopKOutput {
179    pub values: Tensor,
180    pub indices: Tensor,
181}
182
183pub trait TopKLastDimOp {
184    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self.
185    /// This expects a contiguous tensor.
186    /// Note: this implements torch.topk with sorted=True.
187    fn topk(&self, topk: usize) -> Result<TopKOutput>;
188
189    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self.
190    /// This expects a contiguous tensor.
191    /// Note: this implements torch.topk with sorted=False.
192    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
193}
194
195impl TopKLastDimOp for Tensor {
196    fn topk(&self, topk: usize) -> Result<TopKOutput> {
197        // Sorted descending
198        // #[cfg(feature = "cuda")]
199        // let (values, sorted_indices) = self.sort(false)?;
200        // #[cfg(not(feature = "cuda"))]
201        let (values, sorted_indices) = self.sort_last_dim(false)?;
202        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
203        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
204        Ok(TopKOutput {
205            values: topk_values,
206            indices: topk_indices,
207        })
208    }
209
210    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
211        // Sorted descending
212        let TopKOutput { values, indices } = self.topk(topk)?;
213        // Reorder the indices ascending
214        #[cfg(feature = "cuda")]
215        let reorder_indices = indices.arg_sort(true)?;
216        #[cfg(not(feature = "cuda"))]
217        let reorder_indices = indices.arg_sort_last_dim(true)?;
218        let topk_indices_unsorted = indices
219            .to_dtype(DType::F32)?
220            .gather(&reorder_indices, D::Minus1)?
221            .to_dtype(DType::U32)?;
222        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
223        Ok(TopKOutput {
224            values: topk_values_unsorted,
225            indices: topk_indices_unsorted,
226        })
227    }
228}
229
230pub trait RepeatInterleaveOp {
231    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor>;
232    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
233}
234
235impl RepeatInterleaveOp for Tensor {
236    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor> {
237        let dim = dim.to_index(self.shape(), "repeat_interleave")?;
238        let dim_elements = self.dim(dim)?;
239        // For metal
240        assert!(self.dtype().is_float());
241        #[allow(clippy::cast_possible_truncation)]
242        let indices = Tensor::new(
243            (0..dim_elements)
244                .flat_map(|i| vec![i as u32; repeats])
245                .collect::<Vec<_>>(),
246            self.device(),
247        )?;
248        self.index_select(&indices, dim)
249    }
250
251    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
252        let xs = self.flatten_all()?;
253        if repeats.len() != xs.dim(0)? {
254            candle_core::bail!(
255                "repeats ({}) must match flattened self length ({})",
256                repeats.len(),
257                xs.dim(0)?
258            );
259        }
260        #[allow(clippy::cast_possible_truncation)]
261        let indices = Tensor::new(
262            (0..xs.dim(0)?)
263                .flat_map(|i| vec![i as u32; repeats[i] as usize])
264                .collect::<Vec<_>>(),
265            xs.device(),
266        )?;
267        xs.index_select(&indices, 0)
268    }
269}
270
271pub trait SplitOp {
272    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
273}
274
275impl SplitOp for Tensor {
276    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
277        let dim = dim.to_index(self.shape(), "split")?;
278        let mut split_res = Vec::new();
279        let mut index = 0;
280        for split in splits {
281            split_res.push(self.narrow(dim, index, *split)?);
282            index += *split;
283        }
284        Ok(split_res)
285    }
286}
287
288#[allow(dead_code)]
289pub trait BincountOp {
290    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
291}
292
293#[allow(dead_code)]
294fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
295    // let max_val = values.iter().max().copied().unwrap_or(0);
296    // let result_len = (max_val + 1).max(minlength);
297    // values.iter().fold(
298    //     // Start with a histogram vector of zeros.
299    //     vec![0u32; result_len as usize],
300    //     // For each value, update the histogram.
301    //     |mut histogram, &value| {
302    //         histogram[value as usize] += 1;
303    //         histogram
304    //     },
305    // )
306
307    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
308
309    // Early return if there are no values.
310    if values.is_empty() {
311        return vec![0u32; minlength as usize];
312    }
313
314    // Compute the maximum value in parallel.
315    // SAFETY: we know `values` is nonempty.
316    let max_val = *values.par_iter().max().unwrap();
317
318    // The histogram length must cover all observed values as well as `minlength`.
319    let result_len = (max_val + 1).max(minlength) as usize;
320
321    // Build per-thread histograms in parallel.
322    // We use unsafe indexing to eliminate bounds checks in the inner loop.
323    values
324        .par_iter()
325        .fold(
326            || vec![0u32; result_len],
327            |mut local_hist, &v| {
328                // SAFETY: v is guaranteed to be <= max_val, so it is in bounds.
329                unsafe {
330                    *local_hist.get_unchecked_mut(v as usize) += 1;
331                }
332                local_hist
333            },
334        )
335        // Merge the per-thread histograms in parallel.
336        .reduce(
337            || vec![0u32; result_len],
338            |mut global_hist, local_hist| {
339                for i in 0..result_len {
340                    // SAFETY: we know local histogram is at least result_len, as is global_hist
341                    unsafe {
342                        *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
343                    }
344                }
345                global_hist
346            },
347        )
348}
349
350#[allow(dead_code)]
351impl BincountOp for Tensor {
352    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
353        let values = self.to_vec1::<u32>()?;
354
355        Ok(bincount(&values, minlength))
356    }
357}
358
359// https://github.com/mokeyish/candle-ext/blob/ca4547c803469bd51c00ce5eda2f18dd249c8f10/src/triangular.rs#L21
360pub fn apply_triangular(xs: &Tensor, diagonal: isize, upper: bool) -> Result<Tensor> {
361    let device = xs.device();
362    let (l, s) = xs.dims2()?;
363    let mut xs_tri = vec![];
364    for i in 0..l as isize {
365        for j in 0..s as isize {
366            let cond = if upper {
367                i + diagonal > j
368            } else {
369                i + diagonal < j
370            };
371            xs_tri.push(if cond { 0u8 } else { 1u8 });
372        }
373    }
374    xs * Tensor::from_vec(xs_tri, (l, s), device)?.to_dtype(xs.dtype())?
375}
376
377/// Elementwise multiply and activation. The following activations are supported:
378/// - `gelu`
379/// - `silu`
380/// - `relu`
381///
382/// This is equivalent to:
383/// `act(a) * b`
384pub fn mul_and_act(a: &Tensor, b: &Tensor, act: Activation) -> Result<Tensor> {
385    a.apply(&act)? * b
386}
387
#[cfg(test)]
mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        //  [[1, 3, 5],
        //   [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }

    #[test]
    fn test_bincount() {
        use crate::ops::bincount;

        // Empty input: `minlength` zeros.
        assert_eq!(bincount(&[0u32; 0], 3), vec![0, 0, 0]);
        // Length is `max + 1` when that exceeds `minlength`.
        assert_eq!(bincount(&[1, 1, 3], 0), vec![0, 2, 0, 1]);
        // Padded with zeros up to `minlength`.
        assert_eq!(bincount(&[1, 1, 3], 6), vec![0, 2, 0, 1, 0, 0]);
    }
}