mistralrs_quant/utils/mod.rs

1#[cfg(feature = "cuda")]
2mod ffi;
3pub(crate) mod isq;
4pub mod log;
5mod ops;
6mod uqff;
7
8#[cfg(feature = "cuda")]
9pub use ops::gptoss_swiglu_fused;
10#[cfg(feature = "cuda")]
11pub use ops::gptoss_swiglu_interleaved;
12#[cfg(feature = "cuda")]
13pub use ops::softmax_with_sinks;
14pub use ops::{BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp};
15pub use uqff::UQFF_QUANT_TYPE_OFFSET;
16pub(crate) use uqff::{
17    deserialize_tensor, fake_deserialize_tensor, read_dtype, serialize_tensor,
18    version_is_compatible, write_dtype, UQFF_VERSION,
19};
20
21#[cfg(feature = "cuda")]
22use candle_core::{
23    cuda::cudarc::{
24        self,
25        driver::{CudaSlice, DevicePtr, DeviceRepr},
26    },
27    CudaDevice, Device, Tensor,
28};
29
/// Borrows the [`CudaDevice`] backing `x`.
///
/// # Errors
///
/// Returns an error with the message "Expected CUDA device" when
/// `x.device()` is any variant other than [`Device::Cuda`].
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    if let Device::Cuda(cuda_dev) = x.device() {
        Ok(cuda_dev)
    } else {
        // `bail!` expands to an early `return Err(..)`.
        candle_core::bail!("Expected CUDA device")
    }
}
37
/// Returns the device address at which `v.slice(lo..)` begins, paired with
/// a guard obtained from `device_ptr` on the *whole* of `v`.
///
/// The view `v.slice(lo..)` produces its own guard. That guard borrows a view
/// that is local to this function, so it cannot be returned. It is dropped
/// right away instead. The guard handed back borrows `v` itself, so it may
/// outlive this call.
///
/// NOTE(review): callers presumably need to keep the returned guard alive for
/// as long as they use the pointer. The exact drop-time behaviour of
/// `SyncOnDrop` is defined by the `cudarc` version in use; confirm there.
///
/// # Panics
///
/// NOTE(review): presumably panics when `lo > v.len()`, via
/// `CudaSlice::slice`. Confirm against `cudarc`.
#[cfg(feature = "cuda")]
pub fn slice_ptr<T: DeviceRepr>(
    v: &CudaSlice<T>,
    lo: usize,
) -> (u64, cudarc::driver::SyncOnDrop<'_>) {
    let stream = v.stream();

    // Guard over the entire allocation: this is the one the caller receives.
    let (_, whole_guard) = v.device_ptr(stream);

    // Offset address, read through a view that starts at `lo`. The view's
    // guard is dropped (before the view itself) when this block ends.
    let offset_ptr = {
        let tail = v.slice(lo..);
        let (addr, _tail_guard) = tail.device_ptr(stream);
        addr
    };

    (offset_ptr, whole_guard)
}