mistralrs_core/cuda/
ffi.rs

1use std::ffi::c_void;
2
/// Stream handle type used by the declarations below that take a `stream: FfiCudaStream`.
///
/// With the `cuda` feature enabled this is the raw `CUstream` re-exported through candle's
/// cudarc bindings.
#[cfg(feature = "cuda")]
type FfiCudaStream = candle_core::cuda::cudarc::driver::sys::CUstream;
/// Stand-in stream type for builds without the `cuda` feature, so the declarations below
/// still type-check when the cudarc bindings are not compiled in.
#[cfg(not(feature = "cuda"))]
type FfiCudaStream = *const std::ffi::c_void;

// Foreign declarations for externally defined C-ABI symbols. Nothing here is implemented in
// Rust: every function must be provided at link time, and every call is `unsafe`.
//
// NOTE(review): the descriptions below are inferred from the symbol and parameter names only;
// the implementations are not visible from this file. Confirm details (element types, buffer
// sizes, which pointers are device pointers) against the kernel sources before relying on them.
//
// `dead_code` is allowed presumably because not every declaration is referenced in every build
// configuration -- TODO confirm.
#[allow(dead_code)]
extern "C" {
    // ---- count_nonzero_<dtype> ------------------------------------------------------------
    // One declaration per element type. Each takes an input buffer `d_in`, an element count
    // `N` and a stream, and returns a `u32` (by name: the number of nonzero elements).
    pub(crate) fn count_nonzero_bf16(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_f16(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_f32(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_f64(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_u8(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;
    pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32;

    // ---- nonzero_<dtype> ------------------------------------------------------------------
    // One declaration per element type, all with the same parameter list:
    //   d_in        - input buffer of `N` elements
    //   num_nonzero - presumably the value returned by the matching `count_nonzero_*` call
    //   dims        - pointer to `num_dims` dimension sizes (element type not visible here)
    //   d_out       - output buffer written by the kernel
    //   stream      - stream handle
    pub(crate) fn nonzero_bf16(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_f16(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_f32(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_f64(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_u8(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_u32(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_i64(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_i16(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );
    pub(crate) fn nonzero_i32(
        d_in: *const c_void,
        N: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: FfiCudaStream,
    );

    // ---- bitwise_{and,or,xor}_<dtype> -----------------------------------------------------
    // Two input buffers, one output buffer, and an element count `N`. Unlike the functions
    // above, these take no stream argument.
    pub(crate) fn bitwise_and_u8(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_and_u32(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_and_i64(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_and_i32(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_or_u8(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_or_u32(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_or_i64(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_or_i32(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_xor_u8(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_xor_u32(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_xor_i64(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    pub(crate) fn bitwise_xor_i32(
        d_in1: *const c_void,
        d_in2: *const c_void,
        d_out: *mut c_void,
        N: u32,
    );
    // Linked to in mistralrs-quant
    // ---- leftshift_<dtype> ----------------------------------------------------------------
    // One input buffer, one output buffer, an element count `N`, and a shift amount `k`.
    pub(crate) fn leftshift_u8(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32);
    pub(crate) fn leftshift_u32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32);
    pub(crate) fn leftshift_i64(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32);
    pub(crate) fn leftshift_i32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32);

    // ---- asort_{asc,desc}_<dtype> ---------------------------------------------------------
    // All fourteen declarations share one signature:
    //   x       - input buffer, described by `nrows` and `ncols`
    //   dst     - output buffer written by the kernel
    //   inplace - flag forwarded to the kernel (meaning not visible from this file)
    //   stream  - stream handle passed as a plain `i64`, not as `FfiCudaStream`
    pub(crate) fn asort_asc_f32(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_asc_f16(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_asc_bf16(
        x: *const c_void,
        // `*mut`, matching every other `asort_*` declaration: `dst` is the output buffer.
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_asc_f64(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_asc_u8(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_asc_u32(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_asc_i64(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_f32(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_f16(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_bf16(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_f64(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_u8(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_u32(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
    pub(crate) fn asort_desc_i64(
        x: *const c_void,
        dst: *mut c_void,
        nrows: i32,
        ncols: i32,
        inplace: bool,
        stream: i64,
    );
}