// mistralrs_quant/utils/ops.rs

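// Custom candle ops used by the quantization utilities: an element-wise left
// shift, bitwise and/or/xor, and a `nonzero` index op. Each has a CPU forward
// path, with CUDA and Metal paths where kernels are available (`nonzero`
// falls back to the CPU on Metal).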
use candle_core::{
    backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape,
    Tensor, WithDType,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use std::{
    fmt::Display,
    ops::{BitAnd, BitOr, BitXor, Shl},
};

#[cfg(feature = "cuda")]
use crate::utils::ffi;
#[cfg(feature = "cuda")]
use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr};
#[cfg(feature = "cuda")]
use std::ffi::c_void;

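// Element-wise left shift: every element is shifted left by the wrapped
// number of bits.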
struct Leftshift(usize);

impl Leftshift {
    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
        let offset = T::from_f64(self.0 as f64);
        vs.into_par_iter().map(|v| *v << offset).collect()
    }
}

impl CustomOp1 for Leftshift {
    fn name(&self) -> &'static str {
        "left"
    }

    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
        match s1 {
            CpuStorage::U8(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "leftshift" }.bt())?,
                };
                let result = self.leftshift(vs1);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift")),
        }
    }
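    // CUDA path: dedicated kernels exist only for u8 and i32; every other
    // dtype returns `UnsupportedDTypeForOp`.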
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        let dev = s1.device().clone();
        let (d_in1_ptr, elem_count) = match s1.dtype() {
            DType::U8 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<u8>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, elem_count)
            }
            DType::I16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift"));
            }
            DType::U32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
            }
            DType::I64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
            }
            DType::I32 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i32>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, elem_count)
            }
            DType::BF16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
            }
            DType::F16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
            }
            DType::F32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
            }
            DType::F64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
            }
            DType::F8E4M3 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    ffi::leftshift_u8(
                        d_in1_ptr,
                        d_out_ptr,
                        u32::try_from(elem_count)?,
                        self.0 as i32,
                    )
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    ffi::leftshift_i32(
                        d_in1_ptr,
                        d_out_ptr,
                        u32::try_from(elem_count)?,
                        self.0 as i32,
                    )
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-leftshift");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;

        crate::metal_kernels::call_bitwise_leftshift(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            s1.dtype(),
            s1.buffer(),
            l1.start_offset(),
            self.0 as u32,
            out_shape.elem_count(),
            &output,
        )
        .map_err(candle_core::Error::wrap)?;

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}

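/// Extension trait exposing the element-wise left-shift op on `Tensor`.
///
/// A minimal usage sketch, assuming an integer-dtype tensor and a
/// `candle_core::Result` context (float dtypes return
/// `UnsupportedDTypeForOp`):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1u8, 2, 4], (3,), &Device::Cpu)?;
/// let shifted = t.leftshift(1)?; // [2, 4, 8]
/// ```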
#[allow(dead_code)]
pub trait LeftshiftOp {
    fn leftshift(&self, n: usize) -> Result<Tensor>;
}

impl LeftshiftOp for Tensor {
    fn leftshift(&self, n: usize) -> Result<Tensor> {
        self.apply_op1_no_bwd(&Leftshift(n))
    }
}

pub enum BitWiseOpEnum {
    And,
    Or,
    Xor,
}

impl Display for BitWiseOpEnum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BitWiseOpEnum::And => write!(f, "And"),
            BitWiseOpEnum::Or => write!(f, "Or"),
            BitWiseOpEnum::Xor => write!(f, "Xor"),
        }
    }
}

struct BitWise {
    pub op: BitWiseOpEnum,
}

impl BitWise {
    pub fn new(op: BitWiseOpEnum) -> Self {
        Self { op }
    }

    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
        &self,
        vs1: &[T],
        vs2: &[T],
    ) -> Vec<T> {
        vs1.into_par_iter()
            .zip_eq(vs2)
            .map(|(v1, v2)| match self.op {
                BitWiseOpEnum::And => *v1 & *v2,
                BitWiseOpEnum::Or => *v1 | *v2,
                BitWiseOpEnum::Xor => *v1 ^ *v2,
            })
            .collect()
    }
}

impl CustomOp2 for BitWise {
    fn name(&self) -> &'static str {
        "bitwise"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        match s1 {
            CpuStorage::U8(vs1) => {
                let vs2 = s2.as_slice::<u8>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U8(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::U32(vs1) => {
                let vs2 = s2.as_slice::<u32>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::U32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I64(vs1) => {
                let vs2 = s2.as_slice::<i64>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I64(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I16(vs1) => {
                let vs2 = s2.as_slice::<i16>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I16(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::I32(vs1) => {
                let vs2 = s2.as_slice::<i32>().unwrap();
                let vs1 = match l1.contiguous_offsets() {
                    Some((a, b)) => &vs1[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let vs2 = match l2.contiguous_offsets() {
                    Some((a, b)) => &vs2[a..b],
                    None => Err(Error::RequiresContiguous { op: "bitwise" }.bt())?,
                };
                let result = self.bitwise(vs1, vs2);
                let result = CpuStorage::I32(result);
                Ok((result, l1.shape().clone()))
            }
            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")),
            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &CudaStorage,
        l1: &Layout,
        s2: &CudaStorage,
        l2: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        let dev = s1.device().clone();
        let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() {
            DType::U8 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<u8>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<u8>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::U32 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<u32>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<u32>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I64 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i64>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<i64>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I32 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i32>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<i32>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::I16 => {
                let d_in1_ptr = *s1
                    .as_cuda_slice::<i16>()?
                    .slice(l1.start_offset()..)
                    .device_ptr() as *const c_void;
                let d_in2_ptr = *s2
                    .as_cuda_slice::<i16>()?
                    .slice(l2.start_offset()..)
                    .device_ptr() as *const c_void;
                let elem_count = l1.shape().elem_count();
                (d_in1_ptr, d_in2_ptr, elem_count)
            }
            DType::BF16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise"));
            }
            DType::F16 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise"));
            }
            DType::F32 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise"));
            }
            DType::F64 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
            }
            DType::F8E4M3 => {
                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
            }
        };
        let dst = match s1.dtype() {
            DType::U8 => {
                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u8(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::U32 => {
                let d_out = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_u32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I64 => {
                let d_out = unsafe { dev.alloc::<i64>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i64(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            DType::I32 => {
                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
                unsafe {
                    match self.op {
                        BitWiseOpEnum::And => ffi::bitwise_and_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Or => ffi::bitwise_or_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                        BitWiseOpEnum::Xor => ffi::bitwise_xor_i32(
                            d_in1_ptr,
                            d_in2_ptr,
                            d_out_ptr,
                            u32::try_from(elem_count)?,
                        ),
                    }
                };
                CudaStorage::wrap_cuda_slice(d_out, dev)
            }
            _ => unreachable!(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle_core::MetalStorage,
        l1: &Layout,
        s2: &candle_core::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle_core::MetalStorage, Shape)> {
        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: l1.shape().clone(),
                rhs: l2.shape().clone(),
                op: "bitwise-op",
            });
        }
        if s1.dtype() != s2.dtype() {
            return Err(Error::DTypeMismatchBinaryOp {
                lhs: s1.dtype(),
                rhs: s2.dtype(),
                op: "bitwise-op",
            });
        }
        if !l1.is_contiguous() {
            candle_core::bail!("Input tensor s1 must be contiguous");
        }
        if !l2.is_contiguous() {
            candle_core::bail!("Input tensor s2 must be contiguous");
        }

        let command_buffer = s1.device().command_buffer()?;
        command_buffer.set_label("bitwise-op");

        let device = s1.device();

        let out_shape = l1.shape().clone();

        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;

        // Buffer offsets for the kernel calls are in bytes.
        match self.op {
            BitWiseOpEnum::Or => crate::metal_kernels::call_bitwise_or(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
            BitWiseOpEnum::And => crate::metal_kernels::call_bitwise_and(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
            BitWiseOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                s1.dtype(),
                s1.buffer(),
                s2.buffer(),
                l1.start_offset() * s1.dtype().size_in_bytes(),
                l2.start_offset() * s2.dtype().size_in_bytes(),
                out_shape.elem_count(),
                &output,
            )
            .map_err(candle_core::Error::wrap)?,
        }

        let newstorage = candle_core::MetalStorage::new(
            output,
            device.clone(),
            out_shape.elem_count(),
            s1.dtype(),
        );
        Ok((newstorage, out_shape))
    }
}

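/// Extension trait exposing the element-wise bitwise ops on `Tensor`.
///
/// A minimal usage sketch, assuming two same-shape integer tensors and a
/// `candle_core::Result` context:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let a = Tensor::from_vec(vec![0b1100u8, 0b1010], (2,), &Device::Cpu)?;
/// let b = Tensor::from_vec(vec![0b1010u8, 0b0110], (2,), &Device::Cpu)?;
/// let and = a.bitwise_and(&b)?; // [0b1000, 0b0010]
/// let or = a.bitwise_or(&b)?;   // [0b1110, 0b1110]
/// let xor = a.bitwise_xor(&b)?; // [0b0110, 0b1100]
/// ```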
#[allow(dead_code)]
pub trait BitWiseOp {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
}

impl BitWiseOp for Tensor {
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::And))
    }

    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Or))
    }

    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseOpEnum::Xor))
    }
}

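// Custom op computing the n-dimensional indices of every non-zero element,
// flattened row-major into a `Vec<u32>` of `[num_nonzero, ndim]` coordinates.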
struct NonZero;

impl NonZero {
    // Sequential version
    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
        let n = layout.dims().len();
        let mut result = Vec::new();
        let mut indices = vec![0u32; n];
        for (i, v) in vs.iter().enumerate() {
            if !v.is_zero() {
                let mut idx = i;
                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
                    let d = idx % dim;
                    indices[dim_index] = u32::try_from(d).unwrap();
                    idx /= dim;
                }
                result.extend_from_slice(&indices);
            }
        }
        result
    }
}

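// The CUDA path is two-pass: first count the non-zero elements on the device,
// then allocate an exact-size output and have the kernel write the unravelled
// indices into it.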
#[cfg(feature = "cuda")]
fn count_nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) -> u32 {
    unsafe {
        match dtype {
            candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
            candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
            candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
            candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
            candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
            candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
            candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
            candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
            candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}

#[allow(clippy::too_many_arguments)]
#[cfg(feature = "cuda")]
fn nonzero_cuda(
    dtype: candle_core::DType,
    d_in: *const c_void,
    n: u32,
    num_nonzero: u32,
    dims: *const c_void,
    num_dims: u32,
    d_out: *mut c_void,
    stream: candle_core::cuda::cudarc::driver::sys::CUstream,
) {
    unsafe {
        match dtype {
            candle_core::DType::U8 => {
                ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::U32 => {
                ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I64 => {
                ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I32 => {
                ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::I16 => {
                ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::BF16 => {
                ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F16 => {
                ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F32 => {
                ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F64 => {
                ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            }
            candle_core::DType::F8E4M3 => todo!(),
        }
    }
}

impl CustomOp1 for NonZero {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "nonzero" });
        }
        let result = match storage {
            candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
            candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
        };
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
        let result = CpuStorage::U32(result);
        let shape = Shape::from_dims(&[result_len, index_len]);
        Ok((result, shape))
    }
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle_core::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle_core::CudaStorage, Shape)> {
        if !layout.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let dev = storage.device().clone();
        let d_in = match storage.dtype() {
            candle_core::DType::U8 => *storage.as_cuda_slice::<u8>()?.device_ptr(),
            candle_core::DType::U32 => *storage.as_cuda_slice::<u32>()?.device_ptr(),
            candle_core::DType::I32 => *storage.as_cuda_slice::<i32>()?.device_ptr(),
            candle_core::DType::I16 => *storage.as_cuda_slice::<i16>()?.device_ptr(),
            candle_core::DType::I64 => *storage.as_cuda_slice::<i64>()?.device_ptr(),
            candle_core::DType::BF16 => *storage.as_cuda_slice::<half::bf16>()?.device_ptr(),
            candle_core::DType::F16 => *storage.as_cuda_slice::<half::f16>()?.device_ptr(),
            candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
            candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
            candle_core::DType::F8E4M3 => todo!(),
        } as *const c_void;
        let n = layout.shape().elem_count();

        let num_nonzero =
            count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream());
        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
        let d_out_ptr = *d_out.device_ptr() as *mut c_void;
        let dims = layout
            .dims()
            .iter()
            .map(|&x| u32::try_from(x).unwrap())
            .collect::<Vec<u32>>();
        let d_dims = dev
            .htod_copy(dims)
            .map_err(|_| Error::Msg("Failed to copy dims to device".to_string()))?;
        let d_dims_ptr = *d_dims.device_ptr() as *const c_void;
        nonzero_cuda(
            storage.dtype(),
            d_in,
            u32::try_from(n)?,
            num_nonzero,
            d_dims_ptr,
            u32::try_from(layout.dims().len())?,
            d_out_ptr,
            *dev.cu_stream(),
        );
        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
        let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev);
        Ok((dst, shape))
    }
}

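/// Extension trait returning the indices of all non-zero elements of a
/// contiguous tensor as a `[num_nonzero, ndim]` u32 tensor, in row-major
/// order. On Metal, the op runs on the CPU and the result is copied back.
///
/// A minimal usage sketch, assuming a `candle_core::Result` context:
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// let t = Tensor::from_vec(vec![1f32, 0.0, 2.0, 0.0], (2, 2), &Device::Cpu)?;
/// let idx = t.nonzero()?; // [[0, 0], [1, 0]]
/// ```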
pub trait NonZeroOp {
    fn nonzero(&self) -> Result<Tensor>;
}

impl NonZeroOp for Tensor {
    #[cfg(feature = "metal")]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        let original_device = self.device();
        self.to_device(&candle_core::Device::Cpu)?
            .apply_op1_no_bwd(&NonZero)?
            .to_device(original_device)
    }

    #[cfg(not(feature = "metal"))]
    fn nonzero(&self) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(candle_core::Error::RequiresContiguous { op: "nonzero" });
        }
        self.apply_op1_no_bwd(&NonZero)
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_nonzero_cpu() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_nonzero_cuda() {
        use crate::utils::ops::NonZeroOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a = Tensor::from_vec(
            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
            &[2, 4],
            &device,
        )
        .unwrap();
        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
    }

    #[test]
    fn test_bitwise_and_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_and_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
    }

    #[test]
    fn test_bitwise_or_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_or_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_bitwise_xor_cpu() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::Cpu;
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitwise_xor_cuda() {
        use crate::utils::ops::BitWiseOp;
        use candle_core::Tensor;
        let device = candle_core::Device::new_cuda(0).unwrap();
        let a =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
    }

    #[test]
    fn test_nonzero_and() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let input1 = Tensor::from_vec(
            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input2 = Tensor::from_vec(
            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
            (10,),
            &Device::Cpu,
        )
        .unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn nonzero_and_cuda() {
        use crate::utils::ops::{BitWiseOp, NonZeroOp};
        use candle_core::{Device, Tensor};

        let device = Device::new_cuda(0).unwrap();
        let input1 =
            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input2 =
            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
        let input = Tensor::stack(&[input1, input2], 0).unwrap();

        let lt = input.lt(0.0).unwrap();
        let gt = input.gt(-10.0).unwrap();
        let res = lt
            .bitwise_and(&gt)
            .unwrap()
            .nonzero()
            .unwrap()
            .to_vec2::<u32>()
            .unwrap();

        assert_eq!(
            res,
            [
                [0, 3],
                [0, 4],
                [0, 5],
                [0, 6],
                [1, 0],
                [1, 3],
                [1, 5],
                [1, 6]
            ]
        );
    }

    #[test]
    fn test_bitpack_8bit_cpu() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_8bit_cuda() {
        use crate::HqqBits;
        use candle_core::DType;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_dtype(DType::U8)
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_8bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Eight;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
    }

    #[test]
    fn test_bitpack_4bit() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::Cpu;
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_bitpack_4bit_cuda() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_cuda(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }

    #[cfg(feature = "metal")]
    #[test]
    fn test_bitpack_4bit_metal() {
        use crate::HqqBits;
        use candle_core::{Device, Tensor};
        let bits = HqqBits::Four;
        let device = Device::new_metal(0).unwrap();
        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
        let c = bits.bitpack_type()(wq.clone())
            .unwrap()
            .to_vec2::<u8>()
            .unwrap();
        assert_eq!(c, [[19, 36]]);
    }
}