mistralrs_quant/utils/
mod.rs

1#[cfg(feature = "cuda")]
2mod ffi;
3pub(crate) mod isq;
4pub mod log;
5mod ops;
6mod uqff;
7
8pub use ops::{BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp};
9pub use uqff::UQFF_QUANT_TYPE_OFFSET;
10pub(crate) use uqff::{
11    deserialize_tensor, fake_deserialize_tensor, read_dtype, serialize_tensor,
12    version_is_compatible, write_dtype, UQFF_VERSION,
13};
14
15#[cfg(feature = "cuda")]
16use candle_core::{
17    cuda::cudarc::{
18        self,
19        driver::{CudaSlice, DevicePtr, DeviceRepr},
20    },
21    CudaDevice, Device, Tensor,
22};
23
/// Borrows the [`CudaDevice`] that the tensor `x` lives on.
///
/// # Errors
///
/// Returns an error with the message `Expected CUDA device` when `x.device()` is any
/// variant other than [`Device::Cuda`].
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    // Guard clause: anything that is not a CUDA device is rejected up front.
    let Device::Cuda(cuda_dev) = x.device() else {
        candle_core::bail!("Expected CUDA device");
    };
    Ok(cuda_dev)
}
31
/// Returns the raw device pointer of `v` starting at element `lo`, paired with a guard.
///
/// The pointer is read from the sub-slice `v.slice(lo..)`, whereas the returned guard is the
/// one produced by calling `device_ptr` on `v` itself, using `v`'s own stream in both calls.
/// The guard produced for the sub-slice is discarded inside this function.
///
/// NOTE(review): the guard is presumably taken from `v` because the sub-slice is a temporary
/// that cannot outlive this call — confirm against the `cudarc` API.
///
/// NOTE(review): `v.slice(lo..)` likely panics when `lo` exceeds the length of `v` — verify
/// against the `cudarc` documentation before relying on it.
#[cfg(feature = "cuda")]
pub fn slice_ptr<T: DeviceRepr>(
    v: &CudaSlice<T>,
    lo: usize,
) -> (u64, cudarc::driver::SyncOnDrop<'_>) {
    let stream = v.stream();
    // Obtain the guard from the full slice first; it is the one handed back to the caller.
    let (_, whole_guard) = v.device_ptr(stream);
    // Only the address of the offset view is kept; its own guard is dropped at the end of
    // this statement.
    let offset_ptr = v.slice(lo..).device_ptr(stream).0;
    (offset_ptr, whole_guard)
}