1use candle_core::{
2 backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3 Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7use std::{
8 fmt::Display,
9 ops::{BitAnd, BitOr, BitXor, Not, Shl},
10};
11
12#[cfg(feature = "cuda")]
13use crate::utils::{ffi, slice_ptr};
14#[cfg(feature = "cuda")]
15use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage};
16#[cfg(feature = "cuda")]
17use std::ffi::c_void;
18
19#[cfg(feature = "metal")]
use crate::metal_kernels::SortScratchCache;
#[cfg(feature = "metal")]
use std::sync::OnceLock;
23
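// Process-wide cache of Metal sort scratch buffers, shared by the sort/argsort ops
// below so repeated dispatches can reuse their temporary allocations.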
24#[cfg(feature = "metal")]
25static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
26
27struct Leftshift(usize);
28
29impl Leftshift {
30 fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
31 let offset = T::from_f64(self.0 as f64);
32 vs.into_par_iter().map(|v| *v << offset).collect()
33 }
34}
35
36impl CustomOp1 for Leftshift {
37 fn name(&self) -> &'static str {
38 "left"
39 }
40
41 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
42 match s1 {
43 CpuStorage::U8(vs1) => {
44 let vs1 = match l1.contiguous_offsets() {
45 Some((a, b)) => &vs1[a..b],
46 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
47 };
48 let result = self.leftshift(vs1);
49 let result = CpuStorage::U8(result);
50 Ok((result, l1.shape().clone()))
51 }
52 CpuStorage::I16(vs1) => {
53 let vs1 = match l1.contiguous_offsets() {
54 Some((a, b)) => &vs1[a..b],
55 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
56 };
57 let result = self.leftshift(vs1);
58 let result = CpuStorage::I16(result);
59 Ok((result, l1.shape().clone()))
60 }
61 CpuStorage::U32(vs1) => {
62 let vs1 = match l1.contiguous_offsets() {
63 Some((a, b)) => &vs1[a..b],
64 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
65 };
66 let result = self.leftshift(vs1);
67 let result = CpuStorage::U32(result);
68 Ok((result, l1.shape().clone()))
69 }
70 CpuStorage::I64(vs1) => {
71 let vs1 = match l1.contiguous_offsets() {
72 Some((a, b)) => &vs1[a..b],
73 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
74 };
75 let result = self.leftshift(vs1);
76 let result = CpuStorage::I64(result);
77 Ok((result, l1.shape().clone()))
78 }
79 CpuStorage::I32(vs1) => {
80 let vs1 = match l1.contiguous_offsets() {
81 Some((a, b)) => &vs1[a..b],
82 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
83 };
84 let result = self.leftshift(vs1);
85 let result = CpuStorage::I32(result);
86 Ok((result, l1.shape().clone()))
87 }
88 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
89 }
90 }
91
92 #[cfg(feature = "cuda")]
93 fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
94 if !l1.is_contiguous() {
95 candle_core::bail!("Input tensor s1 must be contiguous");
96 }
97 let dev = s1.device().clone();
98 let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
99 DType::U8 => {
100 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
101 let elem_count = l1.shape().elem_count();
102 (d_in1 as *const c_void, d_in1_guard, elem_count)
103 }
104 DType::I32 => {
105 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
106 let elem_count = l1.shape().elem_count();
107 (d_in1 as *const c_void, d_in1_guard, elem_count)
108 }
109 other => {
110 return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
111 }
112 };
113 let dst = match s1.dtype() {
114 DType::U8 => {
115 let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
116 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
117 unsafe {
118 ffi::leftshift_u8(
119 d_in1_ptr,
120 d_out_ptr as *mut std::ffi::c_void,
121 u32::try_from(elem_count)?,
122 self.0 as i32,
123 )
124 };
125 drop(d_out_guard);
126 CudaStorage::wrap_cuda_slice(d_out, dev)
127 }
128 DType::I32 => {
129 let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
130 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
131 unsafe {
132 ffi::leftshift_i32(
133 d_in1_ptr,
134 d_out_ptr as *mut std::ffi::c_void,
135 u32::try_from(elem_count)?,
136 self.0 as i32,
137 )
138 };
139 drop(d_out_guard);
140 CudaStorage::wrap_cuda_slice(d_out, dev)
141 }
142 _ => unreachable!(),
143 };
144 Ok((dst, l1.shape().clone()))
145 }
146
147 #[cfg(feature = "metal")]
148 fn metal_fwd(
149 &self,
150 s1: &candle_core::MetalStorage,
151 l1: &Layout,
152 ) -> Result<(candle_core::MetalStorage, Shape)> {
153 if !l1.is_contiguous() {
154 candle_core::bail!("Input tensor s1 must be contiguous");
155 }
156
157 let encoder = s1.device().command_encoder()?;
158 encoder.set_label("bitwise-leftshift");
159
160 let device = s1.device();
161
162 let out_shape = l1.shape().clone();
163
164 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
165
166 crate::metal_kernels::call_bitwise_leftshift(
167 device.device(),
168 &encoder,
169 &crate::metal_kernels::Kernels::new(),
170 s1.dtype(),
171 s1.buffer(),
172 l1.start_offset(),
173 self.0 as u32,
174 out_shape.elem_count(),
175 &output,
176 )
177 .map_err(candle_core::Error::wrap)?;
178
179 let newstorage = candle_core::MetalStorage::new(
180 output,
181 device.clone(),
182 out_shape.elem_count(),
183 s1.dtype(),
184 );
185 Ok((newstorage, out_shape))
186 }
187}
188
189#[allow(dead_code)]
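/// Element-wise left shift of an integer tensor by a fixed number of bits.
/// Supported dtypes vary by backend; see the `Leftshift` op above.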
190pub trait LeftshiftOp {
191 fn leftshift(&self, n: usize) -> Result<Tensor>;
192}
193
194impl LeftshiftOp for Tensor {
195 fn leftshift(&self, n: usize) -> Result<Tensor> {
196 self.apply_op1_no_bwd(&Leftshift(n))
197 }
198}
199
200pub enum BitWiseBinaryOpEnum {
201 And,
202 Or,
203 Xor,
204}
205
206impl Display for BitWiseBinaryOpEnum {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 BitWiseBinaryOpEnum::And => write!(f, "And"),
210 BitWiseBinaryOpEnum::Or => write!(f, "Or"),
211 BitWiseBinaryOpEnum::Xor => write!(f, "Xor"),
212 }
213 }
214}
215
216pub enum BitWiseUnaryOpEnum {
217 Not,
218}
219
220impl Display for BitWiseUnaryOpEnum {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 BitWiseUnaryOpEnum::Not => write!(f, "Not"),
224 }
225 }
226}
227
228struct BitWise {
229 pub op: BitWiseBinaryOpEnum,
230}
231
232impl BitWise {
233 pub fn new(op: BitWiseBinaryOpEnum) -> Self {
234 Self { op }
235 }
236
237 fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
238 &self,
239 vs1: &[T],
240 vs2: &[T],
241 ) -> Vec<T> {
242 vs1.into_par_iter()
243 .zip_eq(vs2)
244 .map(|(v1, v2)| match self.op {
245 BitWiseBinaryOpEnum::And => *v1 & *v2,
246 BitWiseBinaryOpEnum::Or => *v1 | *v2,
247 BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
248 })
249 .collect()
250 }
251}
252
253impl CustomOp2 for BitWise {
254 fn name(&self) -> &'static str {
255 "bitwise"
256 }
257
258 fn cpu_fwd(
259 &self,
260 s1: &CpuStorage,
261 l1: &Layout,
262 s2: &CpuStorage,
263 l2: &Layout,
264 ) -> Result<(CpuStorage, Shape)> {
265 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
266 return Err(Error::ShapeMismatchBinaryOp {
267 lhs: l1.shape().clone(),
268 rhs: l2.shape().clone(),
269 op: "bitwise-op",
270 });
271 }
272 if s1.dtype() != s2.dtype() {
273 return Err(Error::DTypeMismatchBinaryOp {
274 lhs: s1.dtype(),
275 rhs: s2.dtype(),
276 op: "bitwise-op",
277 });
278 }
279 if !l1.is_contiguous() {
280 candle_core::bail!("Input tensor s1 must be contiguous");
281 }
282 if !l2.is_contiguous() {
283 candle_core::bail!("Input tensor s2 must be contiguous");
284 }
285
286 match s1 {
287 CpuStorage::U8(vs1) => {
288 let vs2 = s2.as_slice::<u8>().unwrap();
289 let vs1 = match l1.contiguous_offsets() {
290 Some((a, b)) => &vs1[a..b],
291 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
292 };
293 let vs2 = match l2.contiguous_offsets() {
294 Some((a, b)) => &vs2[a..b],
295 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
296 };
297 let result = self.bitwise(vs1, vs2);
298 let result = CpuStorage::U8(result);
299 Ok((result, l1.shape().clone()))
300 }
301 CpuStorage::U32(vs1) => {
302 let vs2 = s2.as_slice::<u32>().unwrap();
303 let vs1 = match l1.contiguous_offsets() {
304 Some((a, b)) => &vs1[a..b],
305 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
306 };
307 let vs2 = match l2.contiguous_offsets() {
308 Some((a, b)) => &vs2[a..b],
309 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
310 };
311 let result = self.bitwise(vs1, vs2);
312 let result = CpuStorage::U32(result);
313 Ok((result, l1.shape().clone()))
314 }
315 CpuStorage::I64(vs1) => {
316 let vs2 = s2.as_slice::<i64>().unwrap();
317 let vs1 = match l1.contiguous_offsets() {
318 Some((a, b)) => &vs1[a..b],
319 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
320 };
321 let vs2 = match l2.contiguous_offsets() {
322 Some((a, b)) => &vs2[a..b],
323 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
324 };
325 let result = self.bitwise(vs1, vs2);
326 let result = CpuStorage::I64(result);
327 Ok((result, l1.shape().clone()))
328 }
329 CpuStorage::I16(vs1) => {
330 let vs2 = s2.as_slice::<i16>().unwrap();
331 let vs1 = match l1.contiguous_offsets() {
332 Some((a, b)) => &vs1[a..b],
333 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
334 };
335 let vs2 = match l2.contiguous_offsets() {
336 Some((a, b)) => &vs2[a..b],
337 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
338 };
339 let result = self.bitwise(vs1, vs2);
340 let result = CpuStorage::I16(result);
341 Ok((result, l1.shape().clone()))
342 }
343 CpuStorage::I32(vs1) => {
344 let vs2 = s2.as_slice::<i32>().unwrap();
345 let vs1 = match l1.contiguous_offsets() {
346 Some((a, b)) => &vs1[a..b],
347 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
348 };
349 let vs2 = match l2.contiguous_offsets() {
350 Some((a, b)) => &vs2[a..b],
351 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
352 };
353 let result = self.bitwise(vs1, vs2);
354 let result = CpuStorage::I32(result);
355 Ok((result, l1.shape().clone()))
356 }
357 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
358 }
359 }
360
361 #[cfg(feature = "cuda")]
362 fn cuda_fwd(
363 &self,
364 s1: &CudaStorage,
365 l1: &Layout,
366 s2: &CudaStorage,
367 l2: &Layout,
368 ) -> Result<(CudaStorage, Shape)> {
369 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
370 return Err(Error::ShapeMismatchBinaryOp {
371 lhs: l1.shape().clone(),
372 rhs: l2.shape().clone(),
373 op: "bitwise-op",
374 });
375 }
376 if s1.dtype() != s2.dtype() {
377 return Err(Error::DTypeMismatchBinaryOp {
378 lhs: s1.dtype(),
379 rhs: s2.dtype(),
380 op: "bitwise-op",
381 });
382 }
383 if !l1.is_contiguous() {
384 candle_core::bail!("Input tensor s1 must be contiguous");
385 }
386 if !l2.is_contiguous() {
387 candle_core::bail!("Input tensor s2 must be contiguous");
388 }
389
390 let dev = s1.device().clone();
391 let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
392 DType::U8 => {
393 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
394 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
395 let elem_count = l1.shape().elem_count();
396 (
397 d_in1 as *const std::ffi::c_void,
398 d_in2 as *const std::ffi::c_void,
399 d_in1_guard,
400 d_in2_guard,
401 elem_count,
402 )
403 }
404 DType::U32 => {
405 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
406 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
407 let elem_count = l1.shape().elem_count();
408 (
409 d_in1 as *const std::ffi::c_void,
410 d_in2 as *const std::ffi::c_void,
411 d_in1_guard,
412 d_in2_guard,
413 elem_count,
414 )
415 }
416 DType::I64 => {
417 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
418 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
419 let elem_count = l1.shape().elem_count();
420 (
421 d_in1 as *const std::ffi::c_void,
422 d_in2 as *const std::ffi::c_void,
423 d_in1_guard,
424 d_in2_guard,
425 elem_count,
426 )
427 }
428 DType::I32 => {
429 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
430 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
431 let elem_count = l1.shape().elem_count();
432 (
433 d_in1 as *const std::ffi::c_void,
434 d_in2 as *const std::ffi::c_void,
435 d_in1_guard,
436 d_in2_guard,
437 elem_count,
438 )
439 }
440 DType::I16 => {
441 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
442 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
443 let elem_count = l1.shape().elem_count();
444 (
445 d_in1 as *const std::ffi::c_void,
446 d_in2 as *const std::ffi::c_void,
447 d_in1_guard,
448 d_in2_guard,
449 elem_count,
450 )
451 }
452 other => {
453 return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
454 }
455 };
456 let dst = match s1.dtype() {
457 DType::U8 => {
458 let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
459 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
460 unsafe {
461 match self.op {
462 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
463 d_in1_ptr,
464 d_in2_ptr,
465 d_out_ptr as *mut c_void,
466 u32::try_from(elem_count)?,
467 ),
468 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
469 d_in1_ptr,
470 d_in2_ptr,
471 d_out_ptr as *mut c_void,
472 u32::try_from(elem_count)?,
473 ),
474 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
475 d_in1_ptr,
476 d_in2_ptr,
477 d_out_ptr as *mut c_void,
478 u32::try_from(elem_count)?,
479 ),
480 }
481 };
482 drop(d_out_guard);
483 CudaStorage::wrap_cuda_slice(d_out, dev)
484 }
485 DType::U32 => {
486 let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
487 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
488 unsafe {
489 match self.op {
490 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
491 d_in1_ptr,
492 d_in2_ptr,
493 d_out_ptr as *mut c_void,
494 u32::try_from(elem_count)?,
495 ),
496 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
497 d_in1_ptr,
498 d_in2_ptr,
499 d_out_ptr as *mut c_void,
500 u32::try_from(elem_count)?,
501 ),
502 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
503 d_in1_ptr,
504 d_in2_ptr,
505 d_out_ptr as *mut c_void,
506 u32::try_from(elem_count)?,
507 ),
508 }
509 };
510 drop(d_out_guard);
511 CudaStorage::wrap_cuda_slice(d_out, dev)
512 }
513 DType::I64 => {
514 let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
515 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
516 unsafe {
517 match self.op {
518 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
519 d_in1_ptr,
520 d_in2_ptr,
521 d_out_ptr as *mut c_void,
522 u32::try_from(elem_count)?,
523 ),
524 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
525 d_in1_ptr,
526 d_in2_ptr,
527 d_out_ptr as *mut c_void,
528 u32::try_from(elem_count)?,
529 ),
530 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
531 d_in1_ptr,
532 d_in2_ptr,
533 d_out_ptr as *mut c_void,
534 u32::try_from(elem_count)?,
535 ),
536 }
537 };
538 drop(d_out_guard);
539 CudaStorage::wrap_cuda_slice(d_out, dev)
540 }
541 DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
543 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
544 unsafe {
545 match self.op {
546 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
547 d_in1_ptr,
548 d_in2_ptr,
549 d_out_ptr as *mut c_void,
550 u32::try_from(elem_count)?,
551 ),
552 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
553 d_in1_ptr,
554 d_in2_ptr,
555 d_out_ptr as *mut c_void,
556 u32::try_from(elem_count)?,
557 ),
558 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
559 d_in1_ptr,
560 d_in2_ptr,
561 d_out_ptr as *mut c_void,
562 u32::try_from(elem_count)?,
563 ),
564 }
565 };
566 drop(d_out_guard);
567 CudaStorage::wrap_cuda_slice(d_out, dev)
568 }
            // No i16 launcher is wired up here, so report an error instead of panicking.
            other => return Err(Error::UnsupportedDTypeForOp(other, "bitwise")),
570 };
571 Ok((dst, l1.shape().clone()))
572 }
573
574 #[cfg(feature = "metal")]
575 fn metal_fwd(
576 &self,
577 s1: &candle_core::MetalStorage,
578 l1: &Layout,
579 s2: &candle_core::MetalStorage,
580 l2: &Layout,
581 ) -> Result<(candle_core::MetalStorage, Shape)> {
582 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
583 return Err(Error::ShapeMismatchBinaryOp {
584 lhs: l1.shape().clone(),
585 rhs: l2.shape().clone(),
586 op: "bitwise-op",
587 });
588 }
589 if s1.dtype() != s2.dtype() {
590 return Err(Error::DTypeMismatchBinaryOp {
591 lhs: s1.dtype(),
592 rhs: s2.dtype(),
593 op: "bitwise-op",
594 });
595 }
596 if !l1.is_contiguous() {
597 candle_core::bail!("Input tensor s1 must be contiguous");
598 }
599 if !l2.is_contiguous() {
600 candle_core::bail!("Input tensor s2 must be contiguous");
601 }
602
603 let encoder = s1.device().command_encoder()?;
604 encoder.set_label("bitwise-op");
605
606 let device = s1.device();
607
608 let out_shape = l1.shape().clone();
609
610 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
611
612 match self.op {
613 BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
614 device.device(),
615 &encoder,
616 &crate::metal_kernels::Kernels::new(),
617 s1.dtype(),
618 s1.buffer(),
619 s2.buffer(),
620 l1.start_offset() * s1.dtype().size_in_bytes(),
621 l2.start_offset() * s2.dtype().size_in_bytes(),
622 out_shape.elem_count(),
623 &output,
624 )
625 .map_err(candle_core::Error::wrap)?,
626 BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
627 device.device(),
628 &encoder,
629 &crate::metal_kernels::Kernels::new(),
630 s1.dtype(),
631 s1.buffer(),
632 s2.buffer(),
633 l1.start_offset() * s1.dtype().size_in_bytes(),
634 l2.start_offset() * s2.dtype().size_in_bytes(),
635 out_shape.elem_count(),
636 &output,
637 )
638 .map_err(candle_core::Error::wrap)?,
639 BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
640 device.device(),
641 &encoder,
642 &crate::metal_kernels::Kernels::new(),
643 s1.dtype(),
644 s1.buffer(),
645 s2.buffer(),
646 l1.start_offset() * s1.dtype().size_in_bytes(),
647 l2.start_offset() * s2.dtype().size_in_bytes(),
648 out_shape.elem_count(),
649 &output,
650 )
651 .map_err(candle_core::Error::wrap)?,
652 }
653
654 let newstorage = candle_core::MetalStorage::new(
655 output,
656 device.clone(),
657 out_shape.elem_count(),
658 s1.dtype(),
659 );
660 Ok((newstorage, out_shape))
661 }
662}
663
664struct BitWiseUnary {
665 pub op: BitWiseUnaryOpEnum,
666}
667
668impl BitWiseUnary {
669 pub fn new(op: BitWiseUnaryOpEnum) -> Self {
670 Self { op }
671 }
672
673 fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
674 vs1.into_par_iter()
675 .map(|v1| match self.op {
676 BitWiseUnaryOpEnum::Not => !*v1,
677 })
678 .collect()
679 }
680}
681
682impl CustomOp1 for BitWiseUnary {
683 fn name(&self) -> &'static str {
684 "bitwise-unary"
685 }
686
687 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
688 if !l1.is_contiguous() {
689 candle_core::bail!("Input tensor s1 must be contiguous");
690 }
691
692 match s1 {
693 CpuStorage::U8(vs1) => {
694 let vs1 = match l1.contiguous_offsets() {
695 Some((a, b)) => &vs1[a..b],
696 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
697 };
698 let result = self.bitwise(vs1);
699 let result = CpuStorage::U8(result);
700 Ok((result, l1.shape().clone()))
701 }
702 CpuStorage::U32(vs1) => {
703 let vs1 = match l1.contiguous_offsets() {
704 Some((a, b)) => &vs1[a..b],
705 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
706 };
707 let result = self.bitwise(vs1);
708 let result = CpuStorage::U32(result);
709 Ok((result, l1.shape().clone()))
710 }
711 CpuStorage::I64(vs1) => {
712 let vs1 = match l1.contiguous_offsets() {
713 Some((a, b)) => &vs1[a..b],
714 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
715 };
716 let result = self.bitwise(vs1);
717 let result = CpuStorage::I64(result);
718 Ok((result, l1.shape().clone()))
719 }
720 CpuStorage::I16(vs1) => {
721 let vs1 = match l1.contiguous_offsets() {
722 Some((a, b)) => &vs1[a..b],
723 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
724 };
725 let result = self.bitwise(vs1);
726 let result = CpuStorage::I16(result);
727 Ok((result, l1.shape().clone()))
728 }
729 CpuStorage::I32(vs1) => {
730 let vs1 = match l1.contiguous_offsets() {
731 Some((a, b)) => &vs1[a..b],
732 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
733 };
734 let result = self.bitwise(vs1);
735 let result = CpuStorage::I32(result);
736 Ok((result, l1.shape().clone()))
737 }
738 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
739 }
740 }
741
742 #[cfg(feature = "cuda")]
743 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        candle_core::bail!("Bitwise-not is not implemented for the CUDA backend");
745 }
746
747 #[cfg(feature = "metal")]
748 fn metal_fwd(
749 &self,
750 s1: &candle_core::MetalStorage,
751 l1: &Layout,
752 ) -> Result<(candle_core::MetalStorage, Shape)> {
753 if !l1.is_contiguous() {
754 candle_core::bail!("Input tensor s1 must be contiguous");
755 }
756
757 let encoder = s1.device().command_encoder()?;
758 encoder.set_label("bitwise-unary-op");
759
760 let device = s1.device();
761
762 let out_shape = l1.shape().clone();
763
764 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
765
766 match self.op {
767 BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
768 device.device(),
769 &encoder,
770 &crate::metal_kernels::Kernels::new(),
771 s1.dtype(),
772 s1.buffer(),
773 l1.start_offset() * s1.dtype().size_in_bytes(),
774 out_shape.elem_count(),
775 &output,
776 )
777 .map_err(candle_core::Error::wrap)?,
778 }
779
780 let newstorage = candle_core::MetalStorage::new(
781 output,
782 device.clone(),
783 out_shape.elem_count(),
784 s1.dtype(),
785 );
786 Ok((newstorage, out_shape))
787 }
788}
789
790#[allow(dead_code)]
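/// Element-wise bitwise operations on integer tensors. The binary ops expect both
/// operands to share the same shape, strides, dtype, and a contiguous layout.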
791pub trait BitWiseOp {
792 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
793 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
794 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
795 fn bitwise_not(&self) -> Result<Tensor>;
796}
797
798impl BitWiseOp for Tensor {
799 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
800 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
801 }
802
803 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
804 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
805 }
806
807 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
808 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
809 }
810
811 fn bitwise_not(&self) -> Result<Tensor> {
812 self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
813 }
814}
815
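// Metal-only sort/argsort custom ops; `axis` selects the dimension to sort along.
// Scratch buffers are checked out from `SORT_SCRATCH_CACHE` for each dispatch.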
816#[allow(unused)]
819struct ArgSort {
821 axis: usize,
822}
823
824#[allow(unused)]
825struct Sort {
827 axis: usize,
828}
829
830impl CustomOp1 for ArgSort {
831 fn name(&self) -> &'static str {
832 "argsort"
833 }
834
835 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
837 candle_core::bail!("ArgSort is not implemented for the CPU backend");
838 }
839
840 #[cfg(feature = "cuda")]
842 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
843 candle_core::bail!("ArgSort is not implemented for the CUDA backend");
844 }
845
846 #[cfg(feature = "metal")]
848 fn metal_fwd(
849 &self,
850 s1: &candle_core::MetalStorage,
851 l1: &Layout,
852 ) -> Result<(candle_core::MetalStorage, Shape)> {
853 if !l1.is_contiguous() {
855 candle_core::bail!("Input tensor s1 must be contiguous");
856 }
857
858 let encoder = s1.device().command_encoder()?;
860 encoder.set_label("argsort");
861
862 let device = s1.device();
863 let out_shape = l1.shape().clone();
864 let elem_count = out_shape.elem_count();
865
866 let output = device.new_buffer(elem_count, candle_core::DType::U32, "argsort")?;
868
869 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
873
874 let dims = l1.dims();
875 let size_sorted_axis = dims[self.axis];
876 let n_rows = l1.shape().elem_count() / size_sorted_axis;
877
878 let tn = 4usize;
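        // Pick the block size: each block covers `bn * tn` elements of the sorted
        // axis, and `bn` is capped at 256 for dtypes wider than 4 bytes.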
880 let mut bn = match size_sorted_axis.div_ceil(tn) {
881 v if v > 256 => 512,
882 v if v > 128 => 256,
883 v if v > 64 => 128,
884 v if v > 32 => 64,
885 _ => 32,
886 };
887 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
888 bn = 256;
889 }
890 let n_per_block = bn * tn;
891 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
892
893 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
895
896 let sort_args = crate::metal_kernels::SortArgs {
900 axis: self.axis,
901 shape: l1.dims(),
902 strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
905 in_contiguous: l1.is_contiguous(),
906 in_ty: s1.dtype(),
907 out_ty: candle_core::DType::U32,
908 src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
911 bn,
912 tn,
913 n_blocks,
914 };
915
916 crate::metal_kernels::call_argsort(
918 device.device(),
            &encoder,
            &crate::metal_kernels::Kernels::new(),
921 &sort_args,
922 &scratch,
923 )
924 .map_err(candle_core::Error::wrap)?;
925
926 let newstorage = candle_core::MetalStorage::new(
928 output,
929 device.clone(),
930 elem_count,
931 candle_core::DType::U32,
932 );
933 Ok((newstorage, out_shape))
934 }
935}
936
937impl CustomOp1 for Sort {
938 fn name(&self) -> &'static str {
939 "sort"
940 }
941
942 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
944 candle_core::bail!("Sort is not implemented for the CPU backend");
945 }
946
947 #[cfg(feature = "cuda")]
949 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
950 candle_core::bail!("Sort is not implemented for the CUDA backend");
951 }
952
953 #[cfg(feature = "metal")]
955 fn metal_fwd(
956 &self,
957 s1: &candle_core::MetalStorage,
958 l1: &Layout,
959 ) -> Result<(candle_core::MetalStorage, Shape)> {
960 if !l1.is_contiguous() {
962 candle_core::bail!("Input tensor s1 must be contiguous");
963 }
964
965 let encoder = s1.device().command_encoder()?;
967 encoder.set_label("sort");
968
969 let device = s1.device();
970 let out_shape = l1.shape().clone();
971 let elem_count = out_shape.elem_count();
972
973 let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
975
976 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
980
981 let dims = l1.dims();
982 let size_sorted_axis = dims[self.axis];
983 let n_rows = l1.shape().elem_count() / size_sorted_axis;
984
985 let tn = 4usize;
987 let mut bn = match size_sorted_axis.div_ceil(tn) {
988 v if v > 256 => 512,
989 v if v > 128 => 256,
990 v if v > 64 => 128,
991 v if v > 32 => 64,
992 _ => 32,
993 };
994 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
995 bn = 256;
996 }
997 let n_per_block = bn * tn;
998 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
999
1000 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
1002
1003 let sort_args = crate::metal_kernels::SortArgs {
1007 axis: self.axis,
1008 shape: l1.dims(),
1009 strides: l1.stride(),
            out_shape: l1.dims(),
            out_strides: l1.stride(),
1012 in_contiguous: l1.is_contiguous(),
1013 in_ty: s1.dtype(),
1014 out_ty: s1.dtype(),
1015 src: s1.buffer(),
            src_offset: l1.start_offset(),
            dst: &output,
1018 bn,
1019 tn,
1020 n_blocks,
1021 };
1022
1023 crate::metal_kernels::call_sort(
1025 device.device(),
            &encoder,
            &crate::metal_kernels::Kernels::new(),
1028 &sort_args,
1029 &scratch,
1030 )
1031 .map_err(candle_core::Error::wrap)?;
1032
1033 let newstorage =
1035 candle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1036 Ok((newstorage, out_shape))
1037 }
1038}
1039
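/// Ascending sort and argsort along an axis. CPU and CUDA inputs fall back to
/// candle's `sort_last_dim` / `arg_sort_last_dim` (note that the fallback always
/// sorts the last dimension); Metal inputs use the custom kernels above.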
1040pub trait SortOp {
1042 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1044 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1046}
1047
1048impl SortOp for Tensor {
1049 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1050 if self.device().is_cpu() || self.device().is_cuda() {
1051 return self.arg_sort_last_dim(true);
1052 }
1053 self.apply_op1_no_bwd(&ArgSort {
1054 axis: axis.to_index(self.shape(), "argsort")?,
1055 })
1056 }
1057
1058 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1059 if self.device().is_cpu() || self.device().is_cuda() {
1060 return Ok(self.sort_last_dim(true)?.0);
1061 }
1062 self.apply_op1_no_bwd(&Sort {
1063 axis: axis.to_index(self.shape(), "sort")?,
1064 })
1065 }
1066}
1067
1068struct NonZero;
1069
1070impl NonZero {
1071 fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1073 let n = layout.dims().len();
1074 let mut result = Vec::new();
1075 let mut indices = vec![0u32; n];
1076 for (i, v) in vs.iter().enumerate() {
1077 if !v.is_zero() {
1078 let mut idx = i;
1079 for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1080 let d = idx % dim;
1081 indices[dim_index] = u32::try_from(d).unwrap();
1082 idx /= dim;
1083 }
1084 result.extend_from_slice(&indices);
1085 }
1086 }
1087 result
1088 }
1089}
1090
1091#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1092mod cuda_ops_cccl2 {
1093 use super::*;
1094
1095 pub(super) fn count_nonzero_cuda(
1096 dtype: candle_core::DType,
1097 d_in: *const c_void,
1098 n: u32,
1099 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1100 ) -> u32 {
1101 unsafe {
1102 match dtype {
1103 candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1104 candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1105 candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1106 candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1107 candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1108 candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1109 candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1110 candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1111 candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1112 _ => unreachable!(),
1113 }
1114 }
1115 }
1116
1117 #[allow(clippy::too_many_arguments)]
1118 pub(super) fn nonzero_cuda(
1119 dtype: candle_core::DType,
1120 d_in: *const c_void,
1121 n: u32,
1122 num_nonzero: u32,
1123 dims: *const c_void,
1124 num_dims: u32,
1125 d_out: *mut c_void,
1126 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1127 ) {
1128 unsafe {
1129 match dtype {
1130 candle_core::DType::U8 => {
1131 ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1132 }
1133 candle_core::DType::U32 => {
1134 ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1135 }
1136 candle_core::DType::I64 => {
1137 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1138 }
1139 candle_core::DType::I32 => {
1140 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1141 }
1142 candle_core::DType::I16 => {
1143 ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1144 }
1145 candle_core::DType::BF16 => {
1146 ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1147 }
1148 candle_core::DType::F16 => {
1149 ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1150 }
1151 candle_core::DType::F32 => {
1152 ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1153 }
1154 candle_core::DType::F64 => {
1155 ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1156 }
1157 _ => unreachable!(),
1158 }
1159 }
1160 }
1161}
1162
1163#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1164mod cuda_ops_cccl3 {
1165 use super::*;
1166
1167 pub(super) fn count_nonzero_cuda(
1168 dtype: candle_core::DType,
1169 d_in: *const c_void,
1170 n: u32,
1171 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1172 ) -> u32 {
1173 unsafe {
1174 match dtype {
1175 candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
1176 candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
1177 candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
1178 candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
1179 candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
1180 candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
1181 candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
1182 candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
1183 candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
1184 _ => unreachable!(),
1185 }
1186 }
1187 }
1188
1189 #[allow(clippy::too_many_arguments)]
1190 pub(super) fn nonzero_cuda(
1191 dtype: candle_core::DType,
1192 d_in: *const c_void,
1193 n: u32,
1194 num_nonzero: u32,
1195 dims: *const c_void,
1196 num_dims: u32,
1197 d_out: *mut c_void,
1198 stream: candle_core::cuda::cudarc::driver::sys::CUstream,
1199 ) {
1200 unsafe {
1201 match dtype {
1202 candle_core::DType::U8 => {
1203 ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1204 }
1205 candle_core::DType::U32 => {
1206 ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1207 }
1208 candle_core::DType::I64 => {
1209 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1210 }
1211 candle_core::DType::I32 => {
1212 ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1213 }
1214 candle_core::DType::I16 => {
1215 ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1216 }
1217 candle_core::DType::BF16 => {
1218 ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1219 }
1220 candle_core::DType::F16 => {
1221 ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1222 }
1223 candle_core::DType::F32 => {
1224 ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1225 }
1226 candle_core::DType::F64 => {
1227 ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
1228 }
1229 _ => unreachable!(),
1230 }
1231 }
1232 }
1233}
1234
1235#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1236use cuda_ops_cccl2::{count_nonzero_cuda, nonzero_cuda};
1237#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1238use cuda_ops_cccl3::{count_nonzero_cuda, nonzero_cuda};
1239
1240impl CustomOp1 for NonZero {
1241 fn name(&self) -> &'static str {
1242 "nonzero"
1243 }
1244
1245 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1246 if !layout.is_contiguous() {
1247 return Err(Error::RequiresContiguous { op: "nonzero" });
1248 }
1249 let result = match storage {
1250 candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1251 candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1252 candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1253 candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1254 candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1255 candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1256 candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1257 candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1258 candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1259 _ => unreachable!(),
1260 };
1261 let index_len = layout.dims().len();
1262 let result_len = result.len() / index_len;
1263 let result = CpuStorage::U32(result);
1264 let shape = Shape::from_dims(&[result_len, index_len]);
1265 Ok((result, shape))
1266 }
1267
1268 #[cfg(feature = "cuda")]
1269 fn cuda_fwd(
1270 &self,
1271 storage: &candle_core::CudaStorage,
1272 layout: &Layout,
1273 ) -> Result<(candle_core::CudaStorage, Shape)> {
1274 if !layout.is_contiguous() {
1275 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1276 }
1277 let dev = storage.device().clone();
1278 let (d_in, _d_in_guard) = match storage.dtype() {
1279 candle_core::DType::U8 => {
1280 let slice = storage.as_cuda_slice::<u8>()?;
1281 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1282 (d_in as *const std::ffi::c_void, d_in_guard)
1283 }
1284 candle_core::DType::U32 => {
1285 let slice = storage.as_cuda_slice::<u32>()?;
1286 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1287 (d_in as *const std::ffi::c_void, d_in_guard)
1288 }
1289 candle_core::DType::I32 => {
1290 let slice = storage.as_cuda_slice::<i32>()?;
1291 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1292 (d_in as *const std::ffi::c_void, d_in_guard)
1293 }
1294 candle_core::DType::I16 => {
1295 let slice = storage.as_cuda_slice::<i16>()?;
1296 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1297 (d_in as *const std::ffi::c_void, d_in_guard)
1298 }
1299 candle_core::DType::I64 => {
1300 let slice = storage.as_cuda_slice::<i64>()?;
1301 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1302 (d_in as *const std::ffi::c_void, d_in_guard)
1303 }
1304 candle_core::DType::BF16 => {
1305 let slice = storage.as_cuda_slice::<half::bf16>()?;
1306 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1307 (d_in as *const std::ffi::c_void, d_in_guard)
1308 }
1309 candle_core::DType::F16 => {
1310 let slice = storage.as_cuda_slice::<half::f16>()?;
1311 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1312 (d_in as *const std::ffi::c_void, d_in_guard)
1313 }
1314 candle_core::DType::F32 => {
1315 let slice = storage.as_cuda_slice::<f32>()?;
1316 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1317 (d_in as *const std::ffi::c_void, d_in_guard)
1318 }
1319 candle_core::DType::F64 => {
1320 let slice = storage.as_cuda_slice::<f64>()?;
1321 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1322 (d_in as *const std::ffi::c_void, d_in_guard)
1323 }
1324 _ => unreachable!(),
1325 };
1326 let n = layout.shape().elem_count();
1327
1328 let num_nonzero = count_nonzero_cuda(
1329 storage.dtype(),
1330 d_in,
1331 u32::try_from(n)?,
1332 dev.cuda_stream().cu_stream(),
1333 );
1334 let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1335 .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1336 if num_nonzero != 0 {
1337 let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1338 let dims = layout
1339 .dims()
1340 .iter()
1341 .map(|&x| u32::try_from(x).unwrap())
1342 .collect::<Vec<u32>>();
1343 let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1344 dev.memcpy_htod(&dims, &mut d_dims)?;
1345 let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1346 nonzero_cuda(
1347 storage.dtype(),
1348 d_in,
1349 u32::try_from(n)?,
1350 num_nonzero,
1351 d_dims_ptr as *const c_void,
1352 u32::try_from(layout.dims().len())?,
1353 d_out as *mut c_void,
1354 dev.cuda_stream().cu_stream(),
1355 );
1356 }
1357 let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1358 let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
1359 Ok((dst, shape))
1360 }
1361}
1362
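/// Returns the indices of every non-zero element as a `[count, rank]` u32 tensor,
/// analogous to `torch.nonzero`. On Metal the computation is routed through the
/// CPU implementation.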
1363pub trait NonZeroOp {
1364 fn nonzero(&self) -> Result<Tensor>;
1365}
1366
1367impl NonZeroOp for Tensor {
1368 #[cfg(feature = "metal")]
1369 fn nonzero(&self) -> Result<Tensor> {
1370 if !self.is_contiguous() {
1371 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1372 }
1373 let original_device = self.device();
1374 self.to_device(&candle_core::Device::Cpu)?
1375 .apply_op1_no_bwd(&NonZero)?
1376 .to_device(original_device)
1377 }
1378
1379 #[cfg(not(feature = "metal"))]
1380 fn nonzero(&self) -> Result<Tensor> {
1381 if !self.is_contiguous() {
1382 return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
1383 }
1384 self.apply_op1_no_bwd(&NonZero)
1385 }
1386}
1387
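// Cumulative sum along `axis`. `inclusive` controls whether element i participates
// in its own prefix sum, and `reverse` scans from the end of the axis instead.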
1388struct CumSum {
1389 inclusive: bool,
1390 reverse: bool,
1391 axis: usize,
1392}
1393
1394impl CustomOp1 for CumSum {
1395 fn name(&self) -> &'static str {
1396 "cumsum"
1397 }
1398
1399 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1400 use std::ops::Add;
1401 if !l1.is_contiguous() {
1402 candle_core::bail!("Input tensor s1 must be contiguous");
1403 }
1404 let dims = l1.dims();
1405 let axis = self.axis;
1406 let axis_len = dims[axis];
1407 let (start, end) = l1
1408 .contiguous_offsets()
1409 .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1410
1411 macro_rules! scan_block {
1413 ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1414 let vs: &[$ty] = $vt;
1415 let input = &vs[start..end];
1416 let count = input.len() / axis_len;
1417 let mut result = Vec::<$ty>::with_capacity(input.len());
1418 if !self.reverse {
1419 if self.inclusive {
1420 for block in 0..count {
1421 let base = block * axis_len;
1422 let mut sum = input[base];
1423 result.push(sum);
1424 for j in 1..axis_len {
1425 sum = sum.$add(input[base + j]);
1426 result.push(sum);
1427 }
1428 }
1429 } else {
1430 let init: $ty = $init;
1431 for block in 0..count {
1432 let base = block * axis_len;
1433 let mut sum = init;
1434 for j in 0..axis_len {
1435 result.push(sum);
1436 sum = sum.$add(input[base + j]);
1437 }
1438 }
1439 }
1440 } else {
1441 if self.inclusive {
1442 for block in 0..count {
1443 let base = block * axis_len;
1444 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1445 let mut sum = input[base + axis_len - 1];
1446 temp.push(sum);
1447 for k in 1..axis_len {
1448 let idx = axis_len - 1 - k;
1449 sum = sum.$add(input[base + idx]);
1450 temp.push(sum);
1451 }
1452 temp.reverse();
1453 result.extend(temp);
1454 }
1455 } else {
1456 let init: $ty = $init;
1457 for block in 0..count {
1458 let base = block * axis_len;
1459 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1460 let mut sum = init;
1461 for k in 0..axis_len {
1462 let idx = axis_len - 1 - k;
1463 temp.push(sum);
1464 sum = sum.$add(input[base + idx]);
1465 }
1466 temp.reverse();
1467 result.extend(temp);
1468 }
1469 }
1470 }
1471 result
1472 }};
1473 }
1474 match s1 {
1475 CpuStorage::U8(vs) => {
1476 let result = scan_block!(vs, u8, wrapping_add, 0u8);
1477 Ok((CpuStorage::U8(result), l1.shape().clone()))
1478 }
1479 CpuStorage::I16(vs) => {
1480 let result = scan_block!(vs, i16, add, 0i16);
1481 Ok((CpuStorage::I16(result), l1.shape().clone()))
1482 }
1483 CpuStorage::U32(vs) => {
1484 let result = scan_block!(vs, u32, wrapping_add, 0u32);
1485 Ok((CpuStorage::U32(result), l1.shape().clone()))
1486 }
1487 CpuStorage::I32(vs) => {
1488 let result = scan_block!(vs, i32, add, 0i32);
1489 Ok((CpuStorage::I32(result), l1.shape().clone()))
1490 }
1491 CpuStorage::I64(vs) => {
1492 let result = scan_block!(vs, i64, add, 0i64);
1493 Ok((CpuStorage::I64(result), l1.shape().clone()))
1494 }
1495 CpuStorage::F32(vs) => {
1496 let result = scan_block!(vs, f32, add, 0.0f32);
1497 Ok((CpuStorage::F32(result), l1.shape().clone()))
1498 }
1499 CpuStorage::F64(vs) => {
1500 let result = scan_block!(vs, f64, add, 0.0f64);
1501 Ok((CpuStorage::F64(result), l1.shape().clone()))
1502 }
            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "cumsum")),
1504 }
1505 }
1506
1507 #[cfg(feature = "cuda")]
1508 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
        candle_core::bail!("CumSum is not implemented for the CUDA backend");
1510 }
1511
1512 #[cfg(feature = "metal")]
1513 fn metal_fwd(
1514 &self,
1515 s1: &candle_core::MetalStorage,
1516 l1: &Layout,
1517 ) -> Result<(candle_core::MetalStorage, Shape)> {
1518 use crate::metal_kernels::ScanType;
1519
1520 let encoder = s1.device().command_encoder()?;
1521 encoder.set_label("cumsum");
1522
1523 let device = s1.device();
1524
1525 let out_shape = l1.shape().clone();
1526
1527 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1528
1529 crate::metal_kernels::call_scan(
1530 device.device(),
1531 &encoder,
1532 &crate::metal_kernels::Kernels::new(),
1533 s1.dtype(),
1534 ScanType::Sum,
1535 s1.buffer(),
1536 l1.start_offset() * s1.dtype().size_in_bytes(),
1537 self.axis,
1538 l1.dims(),
1539 l1.stride(),
1540 self.reverse,
1541 self.inclusive,
1542 &output,
1543 )
1544 .map_err(candle_core::Error::wrap)?;
1545
1546 let newstorage = candle_core::MetalStorage::new(
1547 output,
1548 device.clone(),
1549 out_shape.elem_count(),
1550 s1.dtype(),
1551 );
1552 Ok((newstorage, out_shape))
1553 }
1554}
1555
1556#[allow(dead_code)]
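/// Cumulative sum along an axis. `fast_cumsum` is the exclusive forward scan
/// (e.g. `[1, 2, 3, 4]` becomes `[0, 1, 3, 6]`); `fast_cumsum_config` exposes the
/// inclusive and reverse variants.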
1557pub trait CumSumOp {
1558 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1560
1561 fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1562 -> Result<Tensor>;
1563}
1564
1565impl CumSumOp for Tensor {
1566 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1567 self.fast_cumsum_config(axis, false, false)
1568 }
1569
1570 fn fast_cumsum_config<D: Dim>(
1571 &self,
1572 axis: D,
1573 inclusive: bool,
1574 reverse: bool,
1575 ) -> Result<Tensor> {
1576 self.apply_op1_no_bwd(&CumSum {
1577 inclusive,
1578 reverse,
1579 axis: axis.to_index(self.shape(), "cumsum")?,
1580 })
1581 }
1582}
1583
1584#[cfg(feature = "cuda")]
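/// Fused SwiGLU activation for gpt-oss style MLPs (CUDA only). `gate` and `up`
/// must have identical shapes; `alpha` and `limit` are forwarded unchanged to the
/// underlying `gptoss_swiglu_*` kernels.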
1588pub fn gptoss_swiglu_fused(gate: &Tensor, up: &Tensor, alpha: f32, limit: f32) -> Result<Tensor> {
1589 use half::{bf16, f16};
1590
1591 let gate = gate.contiguous()?;
1592 let up = up.contiguous()?;
1593
1594 if gate.shape() != up.shape() {
1595 candle_core::bail!(
1596 "gptoss_swiglu: gate and up must have same shape, got {:?} vs {:?}",
1597 gate.shape(),
1598 up.shape()
1599 );
1600 }
1601
1602 let device = match gate.device() {
1603 candle_core::Device::Cuda(dev) => dev,
1604 _ => candle_core::bail!("gptoss_swiglu requires CUDA device"),
1605 };
1606
1607 let n_elements = gate.elem_count();
1608 let dtype = gate.dtype();
1609
1610 let gate_storage = gate.storage_and_layout().0;
1611 let up_storage = up.storage_and_layout().0;
1612
1613 let gate_cuda = match &*gate_storage {
1614 candle_core::Storage::Cuda(s) => s,
1615 _ => candle_core::bail!("Expected CUDA storage for gate"),
1616 };
1617 let up_cuda = match &*up_storage {
1618 candle_core::Storage::Cuda(s) => s,
1619 _ => candle_core::bail!("Expected CUDA storage for up"),
1620 };
1621
1622 let stream = device.cuda_stream().cu_stream();
1623
1624 match dtype {
1625 DType::F16 => {
1626 let output = device.alloc_zeros::<f16>(n_elements)?;
1627 let gate_slice = gate_cuda.as_cuda_slice::<f16>()?;
1628 let up_slice = up_cuda.as_cuda_slice::<f16>()?;
1629
1630 let (gate_ptr, _g_guard) = slice_ptr(gate_slice, 0);
1631 let (up_ptr, _u_guard) = slice_ptr(up_slice, 0);
1632 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1633
1634 unsafe {
1635 ffi::gptoss_swiglu_f16(
1636 gate_ptr as *const c_void,
1637 up_ptr as *const c_void,
1638 out_ptr as *mut c_void,
1639 n_elements as u32,
1640 alpha,
1641 limit,
1642 stream,
1643 );
1644 }
1645
1646 drop(_o_guard);
1647 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1648 Ok(Tensor::from((
1649 candle_core::Storage::Cuda(out_storage),
1650 gate.shape().clone(),
1651 )))
1652 }
1653 DType::BF16 => {
1654 let output = device.alloc_zeros::<bf16>(n_elements)?;
1655 let gate_slice = gate_cuda.as_cuda_slice::<bf16>()?;
1656 let up_slice = up_cuda.as_cuda_slice::<bf16>()?;
1657
1658 let (gate_ptr, _g_guard) = slice_ptr(gate_slice, 0);
1659 let (up_ptr, _u_guard) = slice_ptr(up_slice, 0);
1660 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1661
1662 unsafe {
1663 ffi::gptoss_swiglu_bf16(
1664 gate_ptr as *const c_void,
1665 up_ptr as *const c_void,
1666 out_ptr as *mut c_void,
1667 n_elements as u32,
1668 alpha,
1669 limit,
1670 stream,
1671 );
1672 }
1673
1674 drop(_o_guard);
1675 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1676 Ok(Tensor::from((
1677 candle_core::Storage::Cuda(out_storage),
1678 gate.shape().clone(),
1679 )))
1680 }
1681 DType::F32 => {
1682 let output = device.alloc_zeros::<f32>(n_elements)?;
1683 let gate_slice = gate_cuda.as_cuda_slice::<f32>()?;
1684 let up_slice = up_cuda.as_cuda_slice::<f32>()?;
1685
1686 let (gate_ptr, _g_guard) = slice_ptr(gate_slice, 0);
1687 let (up_ptr, _u_guard) = slice_ptr(up_slice, 0);
1688 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1689
1690 unsafe {
1691 ffi::gptoss_swiglu_f32(
1692 gate_ptr as *const c_void,
1693 up_ptr as *const c_void,
1694 out_ptr as *mut c_void,
1695 n_elements as u32,
1696 alpha,
1697 limit,
1698 stream,
1699 );
1700 }
1701
1702 drop(_o_guard);
1703 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1704 Ok(Tensor::from((
1705 candle_core::Storage::Cuda(out_storage),
1706 gate.shape().clone(),
1707 )))
1708 }
1709 _ => candle_core::bail!("gptoss_swiglu: unsupported dtype {:?}", dtype),
1710 }
1711}
1712
1713#[cfg(feature = "cuda")]
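/// Variant of [`gptoss_swiglu_fused`] for interleaved projections: `gate_up` must
/// have shape `[N, intermediate_size, 2]` with gate/up pairs in the last dimension,
/// and the result has shape `[N, intermediate_size]` (CUDA only).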
1725pub fn gptoss_swiglu_interleaved(
1726 gate_up: &Tensor,
1727 intermediate_size: usize,
1728 alpha: f32,
1729 limit: f32,
1730) -> Result<Tensor> {
1731 use half::{bf16, f16};
1732 use std::ffi::c_void;
1733
1734 let gate_up = gate_up.contiguous()?;
1735
1736 let dims = gate_up.dims();
1737 if dims.len() != 3 || dims[2] != 2 {
1738 candle_core::bail!(
1739 "gptoss_swiglu_interleaved: expected gate_up shape [N, intermediate_size, 2], got {:?}",
1740 dims
1741 );
1742 }
1743
    let n = dims[0];
    let device = match gate_up.device() {
1746 candle_core::Device::Cuda(dev) => dev,
1747 _ => candle_core::bail!("gptoss_swiglu_interleaved requires CUDA device"),
1748 };
1749
1750 let dtype = gate_up.dtype();
1751 let n_output_elements = n * intermediate_size;
1752
1753 let gate_up_storage = gate_up.storage_and_layout().0;
1754 let gate_up_cuda = match &*gate_up_storage {
1755 candle_core::Storage::Cuda(s) => s,
1756 _ => candle_core::bail!("Expected CUDA storage for gate_up"),
1757 };
1758
1759 let stream = device.cuda_stream().cu_stream();
1760
1761 match dtype {
1762 DType::F16 => {
1763 let output = device.alloc_zeros::<f16>(n_output_elements)?;
1764 let gate_up_slice = gate_up_cuda.as_cuda_slice::<f16>()?;
1765
1766 let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, 0);
1767 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1768
1769 unsafe {
1770 ffi::gptoss_swiglu_interleaved_f16(
1771 gate_up_ptr as *const c_void,
1772 out_ptr as *mut c_void,
1773 n as u32,
1774 intermediate_size as u32,
1775 alpha,
1776 limit,
1777 stream,
1778 );
1779 }
1780
1781 drop(_o_guard);
1782 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1783 Ok(Tensor::from((
1784 candle_core::Storage::Cuda(out_storage),
1785 Shape::from(vec![n, intermediate_size]),
1786 )))
1787 }
1788 DType::BF16 => {
1789 let output = device.alloc_zeros::<bf16>(n_output_elements)?;
1790 let gate_up_slice = gate_up_cuda.as_cuda_slice::<bf16>()?;
1791
1792 let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, 0);
1793 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1794
1795 unsafe {
1796 ffi::gptoss_swiglu_interleaved_bf16(
1797 gate_up_ptr as *const c_void,
1798 out_ptr as *mut c_void,
1799 n as u32,
1800 intermediate_size as u32,
1801 alpha,
1802 limit,
1803 stream,
1804 );
1805 }
1806
1807 drop(_o_guard);
1808 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1809 Ok(Tensor::from((
1810 candle_core::Storage::Cuda(out_storage),
1811 Shape::from(vec![n, intermediate_size]),
1812 )))
1813 }
1814 DType::F32 => {
1815 let output = device.alloc_zeros::<f32>(n_output_elements)?;
1816 let gate_up_slice = gate_up_cuda.as_cuda_slice::<f32>()?;
1817
1818 let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, 0);
1819 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1820
1821 unsafe {
1822 ffi::gptoss_swiglu_interleaved_f32(
1823 gate_up_ptr as *const c_void,
1824 out_ptr as *mut c_void,
1825 n as u32,
1826 intermediate_size as u32,
1827 alpha,
1828 limit,
1829 stream,
1830 );
1831 }
1832
1833 drop(_o_guard);
1834 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1835 Ok(Tensor::from((
1836 candle_core::Storage::Cuda(out_storage),
1837 Shape::from(vec![n, intermediate_size]),
1838 )))
1839 }
1840 _ => candle_core::bail!("gptoss_swiglu_interleaved: unsupported dtype {:?}", dtype),
1841 }
1842}
1843
1844#[cfg(feature = "cuda")]
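/// Softmax over the key dimension of `[batch, heads, q_len, k_len]` attention
/// logits with an extra per-head "sink" logit and an optional mask (CUDA only).
/// The softmax scale is fixed to 1.0 by this wrapper.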
1856pub fn softmax_with_sinks(
1857 logits: &Tensor,
1858 sinks: &Tensor,
1859 mask: Option<&Tensor>,
1860) -> Result<Tensor> {
1861 use half::{bf16, f16};
1862 use std::ffi::c_void;
1863
1864 let logits = logits.contiguous()?;
1865 let sinks = sinks.contiguous()?;
1866
1867 let dims = logits.dims();
1868 if dims.len() != 4 {
1869 candle_core::bail!(
1870 "softmax_with_sinks: expected logits to have 4 dims [b, h, q, k], got {:?}",
1871 dims
1872 );
1873 }
1874
1875 let batch_size = dims[0];
1876 let num_heads = dims[1];
1877 let q_len = dims[2];
1878 let k_len = dims[3];
1879
1880 if sinks.dims() != [num_heads] {
1881 candle_core::bail!(
1882 "softmax_with_sinks: expected sinks shape [{}], got {:?}",
1883 num_heads,
1884 sinks.dims()
1885 );
1886 }
1887
1888 let device = match logits.device() {
1889 candle_core::Device::Cuda(dev) => dev,
1890 _ => candle_core::bail!("softmax_with_sinks requires CUDA device"),
1891 };
1892
1893 let dtype = logits.dtype();
1894 let n_elements = logits.elem_count();
1895
1896 let logits_storage = logits.storage_and_layout().0;
1897 let sinks_storage = sinks.storage_and_layout().0;
1898
1899 let logits_cuda = match &*logits_storage {
1900 candle_core::Storage::Cuda(s) => s,
1901 _ => candle_core::bail!("Expected CUDA storage for logits"),
1902 };
1903 let sinks_cuda = match &*sinks_storage {
1904 candle_core::Storage::Cuda(s) => s,
1905 _ => candle_core::bail!("Expected CUDA storage for sinks"),
1906 };
1907
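    // Make sure the optional mask is contiguous before taking device pointers below.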
1908 let mask = if let Some(m) = mask {
1910 Some(m.contiguous()?)
1911 } else {
1912 None
1913 };
1914
1915 let stream = device.cuda_stream().cu_stream();
1916
1917 match dtype {
1918 DType::F16 => {
1919 let output = device.alloc_zeros::<f16>(n_elements)?;
1920
1921 let logits_slice = logits_cuda.as_cuda_slice::<f16>()?;
1922 let sinks_slice = sinks_cuda.as_cuda_slice::<f16>()?;
1923
1924 let (logits_ptr, _l_guard) = slice_ptr(logits_slice, 0);
1925 let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, 0);
1926 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1927
1928 let mask_ptr = if let Some(ref m) = mask {
1929 let m_storage = m.storage_and_layout().0;
1930 let m_cuda = match &*m_storage {
1931 candle_core::Storage::Cuda(s) => s,
1932 _ => candle_core::bail!("Expected CUDA storage for mask"),
1933 };
1934 let m_slice = m_cuda.as_cuda_slice::<f16>()?;
1935 let (ptr, _guard) = slice_ptr(m_slice, 0);
1936 ptr as *const c_void
1937 } else {
1938 std::ptr::null()
1939 };
1940
1941 unsafe {
1942 ffi::softmax_with_sinks_f16(
1943 logits_ptr as *const c_void,
1944 sinks_ptr as *const c_void,
1945 mask_ptr,
1946 out_ptr as *mut c_void,
1947 batch_size as i32,
1948 num_heads as i32,
1949 q_len as i32,
1950 k_len as i32,
                    1.0,
                    stream,
1953 );
1954 }
1955
1956 drop(_o_guard);
1957 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
1958 Ok(Tensor::from((
1959 candle_core::Storage::Cuda(out_storage),
1960 logits.shape().clone(),
1961 )))
1962 }
1963 DType::BF16 => {
1964 let output = device.alloc_zeros::<bf16>(n_elements)?;
1965
1966 let logits_slice = logits_cuda.as_cuda_slice::<bf16>()?;
1967 let sinks_slice = sinks_cuda.as_cuda_slice::<bf16>()?;
1968
1969 let (logits_ptr, _l_guard) = slice_ptr(logits_slice, 0);
1970 let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, 0);
1971 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
1972
1973 let mask_ptr = if let Some(ref m) = mask {
1974 let m_storage = m.storage_and_layout().0;
1975 let m_cuda = match &*m_storage {
1976 candle_core::Storage::Cuda(s) => s,
1977 _ => candle_core::bail!("Expected CUDA storage for mask"),
1978 };
1979 let m_slice = m_cuda.as_cuda_slice::<bf16>()?;
1980 let (ptr, _guard) = slice_ptr(m_slice, 0);
1981 ptr as *const c_void
1982 } else {
1983 std::ptr::null()
1984 };
1985
1986 unsafe {
1987 ffi::softmax_with_sinks_bf16(
1988 logits_ptr as *const c_void,
1989 sinks_ptr as *const c_void,
1990 mask_ptr,
1991 out_ptr as *mut c_void,
1992 batch_size as i32,
1993 num_heads as i32,
1994 q_len as i32,
1995 k_len as i32,
1996 1.0,
1997 stream,
1998 );
1999 }
2000
2001 drop(_o_guard);
2002 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2003 Ok(Tensor::from((
2004 candle_core::Storage::Cuda(out_storage),
2005 logits.shape().clone(),
2006 )))
2007 }
        DType::F32 => {
            let output = device.alloc_zeros::<f32>(n_elements)?;

            let logits_slice = logits_cuda.as_cuda_slice::<f32>()?;
            let sinks_slice = sinks_cuda.as_cuda_slice::<f32>()?;

            let (logits_ptr, _l_guard) = slice_ptr(logits_slice, 0);
            let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, 0);
            let (out_ptr, _o_guard) = slice_ptr(&output, 0);

            let mask_ptr = if let Some(ref m) = mask {
                let m_storage = m.storage_and_layout().0;
                let m_cuda = match &*m_storage {
                    candle_core::Storage::Cuda(s) => s,
                    _ => candle_core::bail!("Expected CUDA storage for mask"),
                };
                let m_slice = m_cuda.as_cuda_slice::<f32>()?;
                let (ptr, _guard) = slice_ptr(m_slice, 0);
                ptr as *const c_void
            } else {
                std::ptr::null()
            };

            unsafe {
                ffi::softmax_with_sinks_f32(
                    logits_ptr as *const c_void,
                    sinks_ptr as *const c_void,
                    mask_ptr,
                    out_ptr as *mut c_void,
                    batch_size as i32,
                    num_heads as i32,
                    q_len as i32,
                    k_len as i32,
                    1.0,
                    stream,
                );
            }

            drop(_o_guard);
            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
            Ok(Tensor::from((
                candle_core::Storage::Cuda(out_storage),
                logits.shape().clone(),
            )))
        }
        _ => candle_core::bail!("softmax_with_sinks: unsupported dtype {:?}", dtype),
    }
}

#[cfg(test)]
mod tests {
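    // The cumsum tests below cover the four scan variants on [1, 2, 3, 4]:
    // `fast_cumsum(0)` is the default exclusive forward scan ([0, 1, 3, 6]), while
    // `fast_cumsum_config` flips the inclusive and reverse flags for the other three.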
    #[test]
    fn test_cumsum_exclusive_forward_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
        assert_eq!(b, [0, 1, 3, 6]);
    }

    #[test]
    fn test_cumsum_inclusive_forward_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, false)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [1, 3, 6, 10]);
    }

    #[test]
    fn test_cumsum_exclusive_reverse_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, false, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [9, 7, 4, 0]);
    }

    #[test]
    fn test_cumsum_inclusive_reverse_cpu() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [10, 9, 7, 4]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_exclusive_forward_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
        assert_eq!(b, [0, 1, 3, 6]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_inclusive_forward_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, false)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [1, 3, 6, 10]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_exclusive_reverse_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, false, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [9, 7, 4, 0]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_cumsum_inclusive_reverse_metal() {
        use crate::utils::ops::CumSumOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
        let b = a
            .fast_cumsum_config(0, true, true)
            .unwrap()
            .to_vec1::<i64>()
            .unwrap();
        assert_eq!(b, [10, 9, 7, 4]);
    }

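    // `nonzero` returns the row-major coordinates of all non-zero elements as a
    // (num_nonzero, ndims) tensor of u32 indices.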
    #[test]
    fn test_nonzero_cpu() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_cuda() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

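    // The bitwise ops act elementwise on integer tensors using two's-complement
    // semantics, e.g. 1 & -1 == 1, 7 | 8 == 15, and 7 ^ 8 == 15.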
    #[test]
    fn test_bitwise_and_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_and_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
    }

    #[test]
    fn test_bitwise_or_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_bitwise_xor_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_xor_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

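    // Composition tests: build boolean masks with comparisons, combine them with
    // `bitwise_and`, then recover the matching coordinates with `nonzero`.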
    #[test]
    fn test_nonzero_and() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let input1 = Tensor::from_vec(
            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input2 = Tensor::from_vec(
            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_and_cuda() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let device = Device::new_cuda(0).unwrap();
        let input1 =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input2 =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

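    // HQQ 8-bit packing keeps only the low byte of each quantized value (the value
    // modulo 256), as the expected output shows: 257 -> 1, 511 -> 255, 512 -> 0.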
    #[test]
    fn test_bitpack_8bit_cpu() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_8bit_cuda() {
        use crate::HqqBits;
        use candle_core::DType;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_dtype(DType::U8)
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_8bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

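    // HQQ 4-bit packing stores pairs of 4-bit values in single output bytes; the
    // expected bytes here are 19 (0x13) and 36 (0x24).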
    #[test]
    fn test_bitpack_4bit() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_4bit_cuda() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_4bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

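    // Metal sort tests: `fast_sort_asc` returns the sorted values and `fast_argsort_asc`
    // returns the indices that would sort the input along the given dimension.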
    #[cfg(feature = "metal")]
    #[test]
    fn test_sort_and_argsort_vector_metal() {
        use crate::utils::ops::SortOp;
        use candle_core::Tensor;

        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();

        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
        assert_eq!(sorted, [1, 2, 3, 4]);

        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
        assert_eq!(idx, [1, 3, 0, 2]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_sort_and_argsort_matrix_axis1_metal() {
        use crate::utils::ops::SortOp;
        use candle_core::Tensor;

        let device = candle_core::Device::new_metal(0).unwrap();
        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();

        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);

        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_sort_and_argsort_vector_4096_metal() {
        use crate::utils::ops::SortOp;
        use candle_core::Tensor;

        const N: usize = 4096;

        let device = candle_core::Device::new_metal(0).expect("Metal device");

        let vals: Vec<i32> = (0..N as i32).rev().collect();
        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();

        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
        let expected: Vec<i32> = (0..N as i32).collect();
        assert_eq!(sorted, expected);

        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
        for (i, &v) in idx.iter().enumerate() {
            assert_eq!(v as usize, N - 1 - i);
        }
    }
}