mistralrs_quant/utils/
mod.rs1#[cfg(feature = "cuda")]
2mod ffi;
3pub(crate) mod isq;
4pub mod log;
5mod ops;
6mod uqff;
7
8#[cfg(feature = "cuda")]
9pub use ops::gptoss_swiglu_fused;
10#[cfg(feature = "cuda")]
11pub use ops::gptoss_swiglu_interleaved;
12#[cfg(feature = "cuda")]
13pub use ops::softmax_with_sinks;
14pub use ops::{BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp};
15pub use uqff::UQFF_QUANT_TYPE_OFFSET;
16pub(crate) use uqff::{
17 deserialize_tensor, fake_deserialize_tensor, read_dtype, serialize_tensor,
18 version_is_compatible, write_dtype, UQFF_VERSION,
19};
20
21#[cfg(feature = "cuda")]
22use candle_core::{
23 cuda::cudarc::{
24 self,
25 driver::{CudaSlice, DevicePtr, DeviceRepr},
26 },
27 CudaDevice, Device, Tensor,
28};
29
/// Borrows the CUDA device that backs the tensor `x`.
///
/// # Errors
///
/// Returns an error with the message `"Expected CUDA device"` when `x` lives
/// on any device other than `Device::Cuda`.
#[cfg(feature = "cuda")]
pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> {
    // Guard clause: anything that is not a CUDA device is rejected up front.
    let Device::Cuda(cuda_dev) = x.device() else {
        candle_core::bail!("Expected CUDA device");
    };
    Ok(cuda_dev)
}
37
/// Returns the raw device pointer of `v` starting at element offset `lo`,
/// paired with a sync guard obtained from the whole slice `v`.
///
/// The pointer comes from the sub-slice `v.slice(lo..)`, while the guard comes
/// from a separate `device_ptr` call on `v` itself; both calls use `v`'s own
/// stream. The guard produced for the temporary sub-slice is discarded
/// immediately.
///
/// NOTE(review): the guard is presumably taken from `v` rather than from the
/// temporary sub-slice so that it can be returned with a lifetime tied to `v`;
/// confirm against cudarc's `DevicePtr` contract.
///
/// NOTE(review): `v.slice(lo..)` is expected to panic when `lo` is out of
/// range for `v` — verify against the cudarc version in use.
#[cfg(feature = "cuda")]
pub fn slice_ptr<T: DeviceRepr>(
    v: &CudaSlice<T>,
    lo: usize,
) -> (u64, cudarc::driver::SyncOnDrop<'_>) {
    let stream = v.stream();
    // Guard for the full allocation; this is the one handed back to the caller.
    let (_, whole_guard) = v.device_ptr(stream);
    // Pointer at offset `lo`; the sub-slice's own guard is dropped right here.
    let (offset_ptr, _) = v.slice(lo..).device_ptr(stream);
    (offset_ptr, whole_guard)
}