// mistralrs_quant/utils/mod.rs
#[cfg(feature = "cuda")]
mod ffi;
pub(crate) mod isq;
mod ops;

mod uqff;

pub use ops::{BitWiseOp, LeftshiftOp};
pub(crate) use uqff::{
    deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
    HQFF_VERSION,
};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaDType},
    CudaDevice, Device, Storage, Tensor, WithDType,
};

/// Returns a raw device pointer of element type `T` into the CUDA storage of `x`,
/// positioned at the start offset recorded in the tensor's layout.
///
/// The result is a plain `*const T`: nothing in the type system ties it to `x`,
/// so it is up to the caller to use it only while the underlying storage is valid.
///
/// # Errors
///
/// Returns an error if the tensor's storage is not `Storage::Cuda`, or if
/// `as_cuda_slice::<T>()` fails for the requested element type.
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_slice<T: WithDType + CudaDType>(
    x: &Tensor,
) -> candle_core::Result<*const T> {
    // Grab the storage guard and the layout together; the layout supplies the
    // element offset at which this tensor's data begins inside the storage.
    let (storage, layout) = x.storage_and_layout();
    let cuda_storage = match &*storage {
        Storage::Cuda(inner) => inner,
        _ => candle_core::bail!("Expected CUDA storage."),
    };

    // View of the typed slice that begins at the tensor's start offset.
    let view = cuda_storage
        .as_cuda_slice::<T>()?
        .slice(layout.start_offset()..);
    Ok(*view.device_ptr() as *const T)
}

/// Borrows the `CudaDevice` that `x` is placed on.
///
/// # Errors
///
/// Returns an error if the tensor's device is any variant other than `Device::Cuda`.
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    if let Device::Cuda(cuda_dev) = x.device() {
        return Ok(cuda_dev);
    }
    candle_core::bail!("Expected CUDA device")
}