mistralrs_quant/utils/
mod.rs1#[cfg(feature = "cuda")]
2mod ffi;
3pub(crate) mod isq;
4pub mod log;
5mod ops;
6mod uqff;
7
8pub use ops::{BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp};
9pub use uqff::UQFF_QUANT_TYPE_OFFSET;
10pub(crate) use uqff::{
11 deserialize_tensor, fake_deserialize_tensor, read_dtype, serialize_tensor,
12 version_is_compatible, write_dtype, UQFF_VERSION,
13};
14
15#[cfg(feature = "cuda")]
16use candle_core::{
17 cuda::cudarc::{
18 self,
19 driver::{CudaSlice, DevicePtr, DeviceRepr},
20 },
21 CudaDevice, Device, Tensor,
22};
23
/// Returns the CUDA device that backs `x`.
///
/// # Errors
///
/// Fails with "Expected CUDA device" when `x` lives on any non-CUDA device.
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    // Guard clause: anything other than a CUDA device is rejected up front.
    let Device::Cuda(cuda_dev) = x.device() else {
        candle_core::bail!("Expected CUDA device")
    };
    Ok(cuda_dev)
}
31
/// Returns the raw device pointer of `v` offset to element `lo`, paired with the
/// `SyncOnDrop` guard obtained from `v.device_ptr` on `v`'s own stream.
///
/// The returned guard covers the whole of `v`, not just the `lo..` sub-range. It
/// borrows `v`, so `v` must outlive it.
///
/// NOTE(review): the offset pointer comes from a temporary `v.slice(lo..)` view.
/// Presumably this panics when `lo > v.len()` (cudarc `CudaSlice::slice`); confirm.
/// The raw pointer is presumably only meant to be used while the returned guard
/// is alive; confirm against the cudarc `SyncOnDrop` docs.
#[cfg(feature = "cuda")]
pub fn slice_ptr<T: DeviceRepr>(
    v: &CudaSlice<T>,
    lo: usize,
) -> (u64, cudarc::driver::SyncOnDrop<'_>) {
    let stream = v.stream();

    // Acquire the guard for the entire allocation first; this one is handed back.
    let (_, whole_guard) = v.device_ptr(stream);

    // Take the offset pointer from a short-lived view. The view's own guard is
    // released at the end of this scope, before the view itself.
    let offset_ptr = {
        let view = v.slice(lo..);
        let (raw, _view_guard) = view.device_ptr(stream);
        raw
    };

    (offset_ptr, whole_guard)
}