mistralrs_quant/utils/
mod.rs1#[cfg(feature = "cuda")]
2mod ffi;
3pub(crate) mod isq;
4mod ops;
5
6mod uqff;
7
8pub use ops::{BitWiseOp, LeftshiftOp};
9pub use uqff::UQFF_QUANT_TYPE_OFFSET;
10pub(crate) use uqff::{
11 deserialize_tensor, fake_deserialize_tensor, read_dtype, serialize_tensor,
12 version_is_compatible, write_dtype, UQFF_VERSION,
13};
14
15#[cfg(feature = "cuda")]
16use candle_core::{
17 cuda::{cudarc::driver::DevicePtr, CudaDType},
18 CudaDevice, Device, Storage, Tensor, WithDType,
19};
20
#[cfg(feature = "cuda")]
/// Returns a raw CUDA device pointer to the first element of `x`'s view.
///
/// The tensor's CUDA storage is viewed as a slice of `T`. The view is advanced by
/// the layout's start offset, and the device address of that position is returned
/// as `*const T`.
///
/// # Errors
/// - Returns an error if `x` is not backed by `Storage::Cuda`.
/// - Propagates any error from `as_cuda_slice::<T>()`. This is presumably raised
///   when `T` does not match the tensor's dtype (TODO confirm).
///
/// NOTE(review): the storage read guard is dropped when this function returns, and
/// no contiguity or stride check is done here. The pointer is only meaningful while
/// `x`'s storage is alive, and callers presumably handle the layout themselves
/// (verify at call sites).
pub(crate) fn get_cuda_slice<T: WithDType + CudaDType>(
    x: &Tensor,
) -> candle_core::Result<*const T> {
    let (storage_guard, layout) = x.storage_and_layout();
    let Storage::Cuda(cuda_storage) = &*storage_guard else {
        candle_core::bail!("Expected CUDA storage.");
    };
    // Skip to the element where this tensor's view begins inside the shared buffer.
    let view = cuda_storage
        .as_cuda_slice::<T>()?
        .slice(layout.start_offset()..);
    let raw_address = *view.device_ptr();
    Ok(raw_address as *const T)
}
33
#[cfg(feature = "cuda")]
/// Borrows the `CudaDevice` that `x` lives on.
///
/// # Errors
/// Returns an error if `x.device()` is not `Device::Cuda`.
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    if let Device::Cuda(cuda_dev) = x.device() {
        return Ok(cuda_dev);
    }
    candle_core::bail!("Expected CUDA device");
}