mistralrs_core/ops.rs

use candle_core::{shape::Dim, DType, Result, Tensor, D};

#[cfg(feature = "cuda")]
use crate::cuda::ffi;
use crate::layers::Activation;

#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
    inplace: bool,
}

impl candle_core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        _: &candle_core::CpuStorage,
        _: &candle_core::Layout,
    ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> {
        // Only a CUDA kernel is provided; callers fall back to candle's built-in sort on
        // non-CUDA devices.
        panic!("ArgSort has no CPU implementation")
    }

    #[allow(clippy::cast_possible_truncation)]
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &candle_core::Layout,
    ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> {
        use candle_core::backend::BackendStorage;
        use candle_core::cuda_backend::cudarc::driver::DevicePtr;
        use candle_core::cuda_backend::CudaStorageSlice;

        let dev = storage.device();
        let elem_count = layout.shape().elem_count();
        let ncols = self.last_dim as i32;
        let nrows = elem_count as i32 / ncols;
        let dst = unsafe { dev.alloc::<u32>(elem_count) }?;

        use std::ffi::c_void;

        let (src, _src_guard) = match &storage.slice {
            CudaStorageSlice::U8(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::U32(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::I64(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::BF16(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F16(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F32(inp) => inp.device_ptr(inp.stream()),
            CudaStorageSlice::F64(inp) => inp.device_ptr(inp.stream()),
            _ => candle_core::bail!("Unexpected dtype in asort"),
        };
        let src_ptr = src as *const c_void;
        let (dst_ptr, dst_guard) = dst.device_ptr(dst.stream());
        let dst_ptr = dst_ptr as *mut c_void;
        let stream = dev.cuda_stream().cu_stream() as i64;
        unsafe {
            if self.asc {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            } else {
                match storage.dtype() {
                    candle_core::DType::U8 => {
                        ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::U32 => {
                        ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::I64 => {
                        ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::BF16 => {
                        ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F16 => {
                        ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F32 => {
                        ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    candle_core::DType::F64 => {
                        ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream)
                    }
                    _ => candle_core::bail!("Unexpected dtype in asort"),
                }
            }
        }
        drop(dst_guard);
        let dst_ret = candle_core::cuda_backend::CudaStorage {
            slice: CudaStorageSlice::U32(dst),
            device: dev.clone(),
        };
        Ok((dst_ret, layout.shape().clone()))
    }
}

#[allow(dead_code)]
pub trait ArgSortOp {
    fn arg_sort(&self, asc: bool) -> Result<Tensor>;
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>;
}

impl ArgSortOp for Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order
    /// when it comes to ties.
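    ///
    /// A minimal usage sketch (added for illustration, not from the original source). Only the
    /// CUDA backend implements this op, so `self` must live on a CUDA device:
    ///
    /// ```ignore
    /// use crate::ops::ArgSortOp;
    /// // `xs` is assumed to be a contiguous tensor on a CUDA device.
    /// let indices = xs.arg_sort(true)?; // ascending order along the last dim
    /// ```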
    fn arg_sort(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: false,
        })
    }

    /// Sorts the tensor along the last dimension, returning the sorted tensor together with the
    /// sorted indices.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable, so there are no guarantees on the final order
    /// when it comes to ties.
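    ///
    /// A minimal usage sketch (added for illustration, not from the original source); like
    /// `arg_sort`, this is CUDA-only:
    ///
    /// ```ignore
    /// use crate::ops::ArgSortOp;
    /// // `xs` is assumed to be a contiguous tensor on a CUDA device.
    /// let (sorted, indices) = xs.sort(false)?; // descending values plus their indices
    /// ```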
    fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "sort" });
        }
        let last_dim = match self.dims().last() {
            Some(last_dim) => *last_dim,
            None => candle_core::bail!("empty last-dim in arg-sort"),
        };
        let sorted = self.copy()?;

        let asort = sorted.apply_op1_no_bwd(&ArgSort {
            asc,
            last_dim,
            inplace: true,
        })?;

        Ok((sorted, asort))
    }
}

#[allow(dead_code)]
pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}

pub trait TopKLastDimOp {
    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t. `self`.
    /// This expects a contiguous tensor.
    /// Note: this implements torch.topk with sorted=True.
    fn topk(&self, topk: usize) -> Result<TopKOutput>;

    /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t. `self`.
    /// This expects a contiguous tensor.
    /// Note: this implements torch.topk with sorted=False.
    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}

impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Sorted descending
        let (values, sorted_indices) = if self.device().is_cuda() {
            self.sort(false)?
        } else {
            self.sort_last_dim(false)?
        };
        let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
        let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?;
        Ok(TopKOutput {
            values: topk_values,
            indices: topk_indices,
        })
    }

    fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
        // Sorted descending
        let TopKOutput { values, indices } = self.topk(topk)?;
        // Reorder the indices ascending
        // Match the runtime device check used in `topk` above: the custom CUDA argsort has no
        // CPU kernel, so fall back to candle's `arg_sort_last_dim` for non-CUDA tensors.
        let reorder_indices = if indices.device().is_cuda() {
            indices.arg_sort(true)?
        } else {
            indices.arg_sort_last_dim(true)?
        };
        let topk_indices_unsorted = indices
            .to_dtype(DType::F32)?
            .gather(&reorder_indices, D::Minus1)?
            .to_dtype(DType::U32)?;
        let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
        Ok(TopKOutput {
            values: topk_values_unsorted,
            indices: topk_indices_unsorted,
        })
    }
}

/// `torch.repeat_interleave`-style repetition of tensor elements.
pub trait RepeatInterleaveOp {
    /// Repeat each element along `dim` `repeats` times.
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor>;
    /// Flatten `self`, then repeat element `i` `repeats[i]` times.
    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor>;
}

impl RepeatInterleaveOp for Tensor {
    fn repeat_interleave<D: Dim>(&self, repeats: usize, dim: D) -> Result<Tensor> {
        let dim = dim.to_index(self.shape(), "repeat_interleave")?;
        let dim_elements = self.dim(dim)?;
        // Float dtypes only (required for the Metal backend).
        assert!(self.dtype().is_float());
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..dim_elements)
                .flat_map(|i| vec![i as u32; repeats])
                .collect::<Vec<_>>(),
            self.device(),
        )?;
        self.index_select(&indices, dim)
    }

    fn repeat_interleave_flat(&self, repeats: Vec<u32>) -> Result<Tensor> {
        let xs = self.flatten_all()?;
        if repeats.len() != xs.dim(0)? {
            candle_core::bail!(
                "repeats ({}) must match flattened self length ({})",
                repeats.len(),
                xs.dim(0)?
            );
        }
        #[allow(clippy::cast_possible_truncation)]
        let indices = Tensor::new(
            (0..xs.dim(0)?)
                .flat_map(|i| vec![i as u32; repeats[i] as usize])
                .collect::<Vec<_>>(),
            xs.device(),
        )?;
        xs.index_select(&indices, 0)
    }
}

/// Split a tensor into chunks of the given sizes along `dim`, like `torch.split`.
pub trait SplitOp {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>>;
}

impl SplitOp for Tensor {
    fn split<D: Dim>(&self, splits: &[usize], dim: D) -> Result<Vec<Tensor>> {
        let dim = dim.to_index(self.shape(), "split")?;
        let mut split_res = Vec::new();
        let mut index = 0;
        for split in splits {
            split_res.push(self.narrow(dim, index, *split)?);
            index += *split;
        }
        Ok(split_res)
    }
}

/// Histogram of a `u32` tensor, like `torch.bincount`.
#[allow(dead_code)]
pub trait BincountOp {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>>;
}

#[allow(dead_code)]
fn bincount(values: &[u32], minlength: u32) -> Vec<u32> {
    // A plain sequential fold over `values` would also work here; the parallel version below
    // builds one histogram per thread and then merges them, which scales better for large inputs.

    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    // Early return if there are no values.
    if values.is_empty() {
        return vec![0u32; minlength as usize];
    }

    // Compute the maximum value in parallel.
    // We just checked that values is nonempty above, so max() will return Some.
    // Using expect() for a clearer error message if this invariant is somehow violated.
    let max_val = *values
        .par_iter()
        .max()
        .expect("values should be non-empty after empty check");

    // The histogram length must cover all observed values as well as `minlength`.
    let result_len = (max_val + 1).max(minlength) as usize;

    // Build per-thread histograms in parallel.
    // We use unsafe indexing to eliminate bounds checks in the inner loop.
    values
        .par_iter()
        .fold(
            || vec![0u32; result_len],
            |mut local_hist, &v| {
                // SAFETY: v is guaranteed to be <= max_val, so it is in bounds.
                unsafe {
                    *local_hist.get_unchecked_mut(v as usize) += 1;
                }
                local_hist
            },
        )
        // Merge the per-thread histograms in parallel.
        .reduce(
            || vec![0u32; result_len],
            |mut global_hist, local_hist| {
                for i in 0..result_len {
                    // SAFETY: both local_hist and global_hist have length result_len.
                    unsafe {
                        *global_hist.get_unchecked_mut(i) += local_hist.get_unchecked(i);
                    }
                }
                global_hist
            },
        )
}

#[allow(dead_code)]
impl BincountOp for Tensor {
    fn bincount(&self, minlength: u32) -> Result<Vec<u32>> {
        let values = self.to_vec1::<u32>()?;

        Ok(bincount(&values, minlength))
    }
}

// https://github.com/mokeyish/candle-ext/blob/ca4547c803469bd51c00ce5eda2f18dd249c8f10/src/triangular.rs#L21
/// Mask a 2-D tensor relative to `diagonal`: when `upper` is true, keep elements with
/// `j >= i + diagonal` (like `torch.triu`); otherwise keep elements with `j <= i + diagonal`
/// (like `torch.tril`). Everything else is zeroed.
pub fn apply_triangular(xs: &Tensor, diagonal: isize, upper: bool) -> Result<Tensor> {
    let device = xs.device();
    let (l, s) = xs.dims2()?;
    let mut xs_tri = vec![];
    for i in 0..l as isize {
        for j in 0..s as isize {
            let cond = if upper {
                i + diagonal > j
            } else {
                i + diagonal < j
            };
            xs_tri.push(if cond { 0u8 } else { 1u8 });
        }
    }
    xs * Tensor::from_vec(xs_tri, (l, s), device)?.to_dtype(xs.dtype())?
}

/// Elementwise multiply and activation. The following activations are supported:
/// - `gelu`
/// - `silu`
/// - `relu`
///
/// This is equivalent to:
/// `act(a) * b`
pub fn mul_and_act(a: &Tensor, b: &Tensor, act: Activation) -> Result<Tensor> {
    a.apply(&act)? * b
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_topk() {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        //  [[1, 3, 5],
        //   [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &device)
            .unwrap()
            .reshape((3, 2))
            .unwrap()
            .t()
            .unwrap()
            .contiguous()
            .unwrap();
        let TopKOutput { values, indices } = x.topk(2).unwrap();
        assert_eq!(
            x.to_vec2::<f32>().unwrap(),
            vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
        );
        assert_eq!(
            values.to_vec2::<f32>().unwrap(),
            vec![vec![5f32, 3f32], vec![6f32, 4f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>().unwrap(),
            vec![vec![2u32, 1u32], vec![2u32, 1u32]]
        );
    }
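
    // Added example (not from the original source): exercises `topk_unsorted` on the CPU path,
    // checking that values and indices come back in their original (ascending-index) order.
    #[test]
    fn test_topk_unsorted() -> candle_core::Result<()> {
        use crate::ops::{TopKLastDimOp, TopKOutput};
        use candle_core::{Device, Tensor};
        //  [[1, 3, 5],
        //   [2, 4, 6]]
        let x = Tensor::arange(1f32, 7f32, &Device::Cpu)?
            .reshape((3, 2))?
            .t()?
            .contiguous()?;
        let TopKOutput { values, indices } = x.topk_unsorted(2)?;
        // The top-2 values per row, reported in index order rather than sorted order.
        assert_eq!(
            values.to_vec2::<f32>()?,
            vec![vec![3f32, 5f32], vec![4f32, 6f32]]
        );
        assert_eq!(
            indices.to_vec2::<u32>()?,
            vec![vec![1u32, 2u32], vec![1u32, 2u32]]
        );
        Ok(())
    }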

    #[test]
    fn test_repeat_interleave() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(
            vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]],
            &Device::Cpu,
        )?;

        let repeat_interleaved = input.repeat_interleave(2, 2)?;
        assert_eq!(
            repeat_interleaved.to_vec3::<f32>()?,
            vec![vec![
                vec![1., 1., 2., 2., 3., 3.],
                vec![4., 4., 5., 5., 6., 6.]
            ]]
        );

        Ok(())
    }
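
    // Added example (not from the original source): demonstrates `SplitOp::split`, which mirrors
    // `torch.split` with explicit chunk sizes along a dimension.
    #[test]
    fn test_split() -> candle_core::Result<()> {
        use crate::ops::SplitOp;
        use candle_core::{Device, Tensor};

        //  [[0, 1, 2],
        //   [3, 4, 5]]
        let x = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
        let chunks = x.split(&[1, 2], 1)?;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].to_vec2::<f32>()?, vec![vec![0f32], vec![3f32]]);
        assert_eq!(
            chunks[1].to_vec2::<f32>()?,
            vec![vec![1f32, 2f32], vec![4f32, 5f32]]
        );
        Ok(())
    }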

    #[test]
    fn test_repeat_interleave_flat() -> candle_core::Result<()> {
        use crate::ops::RepeatInterleaveOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?;

        let repeat_interleaved = input.repeat_interleave_flat(vec![1u32, 2u32, 3u32, 4u32])?;
        assert_eq!(
            repeat_interleaved.to_vec1::<f64>()?,
            vec![1., 2., 2., 3., 3., 3., 4., 4., 4., 4.]
        );

        Ok(())
    }
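
    // Added example (not from the original source): checks the parallel `bincount` against a
    // hand-computed histogram, including the `minlength` padding behaviour.
    #[test]
    fn test_bincount() -> candle_core::Result<()> {
        use crate::ops::BincountOp;
        use candle_core::{Device, Tensor};

        let input = Tensor::new(vec![0u32, 1, 1, 3], &Device::Cpu)?;
        // Without padding, the histogram covers 0..=max.
        assert_eq!(input.bincount(0)?, vec![1u32, 2, 0, 1]);
        // `minlength` pads the histogram with trailing zero counts.
        assert_eq!(input.bincount(6)?, vec![1u32, 2, 0, 1, 0, 0]);
        Ok(())
    }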
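
    // Added example (not from the original source): `apply_triangular` with `diagonal = 0` keeps
    // the upper (resp. lower) triangle including the diagonal, matching torch.triu / torch.tril.
    #[test]
    fn test_apply_triangular() -> candle_core::Result<()> {
        use crate::ops::apply_triangular;
        use candle_core::{DType, Device, Tensor};

        let xs = Tensor::ones((3, 3), DType::F32, &Device::Cpu)?;
        let upper = apply_triangular(&xs, 0, true)?;
        assert_eq!(
            upper.to_vec2::<f32>()?,
            vec![
                vec![1f32, 1., 1.],
                vec![0f32, 1., 1.],
                vec![0f32, 0., 1.]
            ]
        );
        let lower = apply_triangular(&xs, 0, false)?;
        assert_eq!(
            lower.to_vec2::<f32>()?,
            vec![
                vec![1f32, 0., 0.],
                vec![1f32, 1., 0.],
                vec![1f32, 1., 1.]
            ]
        );
        Ok(())
    }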
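
    // Added example (not from the original source); assumes `crate::layers::Activation` exposes a
    // `Relu` variant, as suggested by the doc comment on `mul_and_act`.
    #[test]
    fn test_mul_and_act() -> candle_core::Result<()> {
        use crate::layers::Activation;
        use crate::ops::mul_and_act;
        use candle_core::{Device, Tensor};

        let a = Tensor::new(vec![vec![-1f32, 2.]], &Device::Cpu)?;
        let b = Tensor::new(vec![vec![3f32, 4.]], &Device::Cpu)?;
        // relu(a) = [[0, 2]], so act(a) * b = [[0, 8]].
        let out = mul_and_act(&a, &b, Activation::Relu)?;
        assert_eq!(out.to_vec2::<f32>()?, vec![vec![0f32, 8f32]]);
        Ok(())
    }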
}