mistralrs_quant/utils/
mod.rs

1#[cfg(feature = "cuda")]
2mod ffi;
3pub(crate) mod isq;
4mod ops;
5
6mod uqff;
7
8pub use ops::{BitWiseOp, LeftshiftOp};
9pub use uqff::UQFF_QUANT_TYPE_OFFSET;
10pub(crate) use uqff::{
11    deserialize_tensor, fake_deserialize_tensor, read_dtype, serialize_tensor,
12    version_is_compatible, write_dtype, UQFF_VERSION,
13};
14
15#[cfg(feature = "cuda")]
16use candle_core::{
17    cuda::{cudarc::driver::DevicePtr, CudaDType},
18    CudaDevice, Device, Storage, Tensor, WithDType,
19};
20
/// Returns a raw device pointer to the CUDA buffer backing `x`, advanced to
/// the tensor layout's start offset and typed as `*const T`.
///
/// Only the layout's start offset is applied here; strides and contiguity are
/// not inspected by this function.
///
/// # Errors
///
/// - Bails with `"Expected CUDA storage."` when the tensor's storage is not
///   the `Storage::Cuda` variant.
/// - Propagates any error from `as_cuda_slice::<T>()` (presumably a mismatch
///   between `T` and the stored dtype — verify against candle).
///
/// NOTE(review): the result is a raw pointer with no borrow attached, so the
/// caller presumably has to keep `x` alive while using it — confirm at call sites.
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_slice<T: WithDType + CudaDType>(
    x: &Tensor,
) -> candle_core::Result<*const T> {
    // Element offset at which this tensor's data begins inside its storage.
    let start = x.layout().start_offset();

    // Keep the storage guard in a named binding so it outlives the borrows below.
    let (storage_guard, _layout) = x.storage_and_layout();
    let cuda_storage = match &*storage_guard {
        Storage::Cuda(inner) => inner,
        _ => candle_core::bail!("Expected CUDA storage."),
    };

    // View the typed slice from the start offset onward and read its device address.
    let offset_view = cuda_storage.as_cuda_slice::<T>()?.slice(start..);
    let device_addr = *offset_view.device_ptr();
    Ok(device_addr as *const T)
}
33
/// Borrows the [`CudaDevice`] that tensor `x` reports as its device.
///
/// # Errors
///
/// Bails with `"Expected CUDA device"` when `x.device()` is any variant other
/// than `Device::Cuda`.
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    // Guard clause: hand back the CUDA device as soon as it is recognised.
    if let Device::Cuda(cuda_dev) = x.device() {
        return Ok(cuda_dev);
    }
    candle_core::bail!("Expected CUDA device")
}