use crate::core::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D};
use k_quants::*;
use std::{borrow::Cow, mem, sync::Arc};
#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod imatrix_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::{bf16, f16};
pub use k_quants::GgmlType;
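/// A tensor holding quantized data together with its logical shape.
///
/// The raw bytes live in a [`QStorage`], which is either a boxed collection of quantized
/// blocks on the CPU or a device buffer on CUDA/Metal.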
pub struct QTensor {
storage: QStorage,
shape: Shape,
}
impl Device {
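    /// Allocates zero-initialized quantized storage for `elem_count` elements of `dtype` on
    /// this device.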
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
match self {
Device::Cpu => {
let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage))
}
Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage))
}
Device::Cuda(cuda) => {
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
Ok(QStorage::Cuda(storage))
}
}
}
}
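/// Backing storage of a [`QTensor`], specialized per device.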
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage),
Cuda(cuda::QCudaStorage),
}
impl QStorage {
fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(),
QStorage::Cuda(storage) => storage.dtype().block_size(),
}
}
fn dtype(&self) -> GgmlDType {
match self {
QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(),
QStorage::Cuda(storage) => storage.dtype(),
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
}
}
fn quantize(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
            _ => crate::bail!("invalid quantize: storage devices do not match"),
}
Ok(())
}
fn quantize_imatrix(
&mut self,
src: &Storage,
imatrix_weights: &[f32],
n_per_row: usize,
) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => {
storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
}
(QStorage::Cuda(storage), Storage::Cuda(src)) => {
storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
}
            _ => crate::bail!("invalid quantize_imatrix: storage devices do not match"),
}
Ok(())
}
fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
(QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            _ => crate::bail!("invalid quantize_onto: source storage is not on the cpu"),
}
Ok(())
}
fn quantize_imatrix_onto(
&mut self,
src: &Storage,
imatrix_weights: &[f32],
n_per_row: usize,
) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row)?;
}
(QStorage::Metal(storage), Storage::Cpu(src)) => {
storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
}
(QStorage::Cuda(storage), Storage::Cpu(src)) => {
storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
}
            _ => crate::bail!("invalid quantize_imatrix_onto: source storage is not on the cpu"),
}
Ok(())
}
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
}
}
fn data(&self) -> Result<Cow<[u8]>> {
match self {
QStorage::Cpu(storage) => {
let data_ptr = storage.as_ptr();
let size_in_bytes = storage.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
}
}
}
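/// The GGML data types, using the same numeric identifiers as GGML/GGUF files.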
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
F32,
F16,
BF16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
}
impl GgmlDType {
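    /// Parses the numeric type identifier used in GGML/GGUF files.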
pub(crate) fn from_u32(u: u32) -> Result<Self> {
let dtype = match u {
0 => Self::F32,
1 => Self::F16,
2 => Self::Q4_0,
3 => Self::Q4_1,
6 => Self::Q5_0,
7 => Self::Q5_1,
8 => Self::Q8_0,
9 => Self::Q8_1,
10 => Self::Q2K,
11 => Self::Q3K,
12 => Self::Q4K,
13 => Self::Q5K,
14 => Self::Q6K,
15 => Self::Q8K,
30 => Self::BF16,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
}
pub(crate) fn to_u32(self) -> u32 {
match self {
Self::F32 => 0,
Self::F16 => 1,
Self::Q4_0 => 2,
Self::Q4_1 => 3,
Self::Q5_0 => 6,
Self::Q5_1 => 7,
Self::Q8_0 => 8,
Self::Q8_1 => 9,
Self::Q2K => 10,
Self::Q3K => 11,
Self::Q4K => 12,
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
Self::BF16 => 30,
}
}
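    /// Creates zero-initialized CPU storage able to hold `elem_count` elements of this dtype.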
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
}
}
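    /// The size in bytes of one block (or of one element for the float types).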
pub fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
Self::F16 | Self::BF16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
}
}
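    /// The number of elements packed into one quantization block (1 for the float types).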
pub fn block_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 | Self::BF16 => 1,
Self::Q4_0 => k_quants::QK4_0,
Self::Q4_1 => k_quants::QK4_1,
Self::Q5_0 => k_quants::QK5_0,
Self::Q5_1 => k_quants::QK5_1,
Self::Q8_0 => k_quants::QK8_0,
Self::Q8_1 => k_quants::QK8_1,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
}
}
}
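/// Type-erased interface over CPU-side quantized storage, implemented below for `Vec<T>` of
/// every block type that implements [`k_quants::GgmlType`].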
pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8;
fn block_size(&self) -> usize;
#[allow(clippy::wrong_self_convention)]
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
#[allow(clippy::wrong_self_convention)]
fn from_float_imatrix(
&mut self,
xs: &[f32],
imatrix_weights: &[f32],
n_per_row: usize,
) -> Result<()>;
fn size(&self) -> usize;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
}
fn size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
T::from_float(xs, self)
}
fn from_float_imatrix(
&mut self,
xs: &[f32],
imatrix_weights: &[f32],
n_per_row: usize,
) -> Result<()> {
T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
}
fn dtype(&self) -> GgmlDType {
T::DTYPE
}
fn block_size(&self) -> usize {
T::BLCK_SIZE
}
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
let mut ys = vec![0.0f32; elem_count];
T::to_float(self.as_slice(), &mut ys)?;
Ok(CpuStorage::F32(ys))
}
fn storage_size_in_bytes(&self) -> usize {
self.len() * std::mem::size_of::<T>()
}
fn as_ptr(&self) -> *const u8 {
self.as_ptr() as *const u8
}
}
impl std::fmt::Debug for QTensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
}
}
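/// Checks that a shape can be quantized: it must not be a scalar and its last dimension must
/// be a multiple of the block size.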
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
let dims = shape.dims();
if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}")
}
if dims[dims.len() - 1] % block_size != 0 {
        crate::bail!(
            "quantized tensors must have their last dim divisible by the block size {shape:?} {}",
            block_size
        )
}
Ok(())
}
impl QTensor {
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
let shape = shape.into();
check_shape(&shape, storage.block_size())?;
Ok(Self { storage, shape })
}
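    /// Quantizes `src` into a new `QTensor` of the given `dtype` on the same device as `src`.
    ///
    /// The source is converted to `f32` before quantization and its last dimension must be
    /// divisible by `dtype.block_size()`. A rough usage sketch (illustrative only, the exact
    /// import paths depend on how this crate is consumed):
    ///
    /// ```ignore
    /// // `weight` is assumed to be an f32 tensor of shape (out_features, in_features).
    /// let qweight = QTensor::quantize(&weight, GgmlDType::Q4K)?;
    /// let approx = qweight.dequantize(weight.device())?;
    /// ```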
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::core::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
)
}
let mut storage = src.device().qzeros(elem_count, dtype)?;
storage.quantize(&src.storage())?;
Ok(Self {
storage,
shape: shape.clone(),
})
}
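    /// Like [`Self::quantize`], but steers the quantization with per-column importance
    /// weights, one per element of the last dimension of `src`.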
pub fn quantize_imatrix(
src: &Tensor,
imatrix_weights: &[f32],
dtype: GgmlDType,
) -> Result<Self> {
let n_per_row = src.dim(D::Minus1)?;
if imatrix_weights.len() != n_per_row {
crate::bail!(
"imatrix weights must have the same length {} as the last dim of src {}",
imatrix_weights.len(),
                n_per_row
);
}
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::core::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
);
}
let mut storage = src.device().qzeros(elem_count, dtype)?;
storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
Ok(Self {
storage,
shape: shape.clone(),
})
}
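    /// Like [`Self::quantize_imatrix`], but takes a CPU `src` and quantizes it directly onto
    /// the target device `dev`.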
pub fn quantize_imatrix_onto(
src: &Tensor,
imatrix_weights: &[f32],
dtype: GgmlDType,
dev: &Device,
) -> Result<Self> {
if !src.device().is_cpu() {
            crate::bail!(
                "`quantize_imatrix_onto` expects `src` to be on the cpu, got {:?}.",
                src.device()
            )
}
let n_per_row = src.dim(D::Minus1)?;
if imatrix_weights.len() != n_per_row {
crate::bail!(
"imatrix weights must have the same length {} as the last dim of src {}",
imatrix_weights.len(),
                n_per_row
);
}
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::core::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
)
}
let mut storage = dev.qzeros(elem_count, dtype)?;
storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
Ok(Self {
storage,
shape: shape.clone(),
})
}
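    /// Like [`Self::quantize`], but takes a CPU `src` and quantizes it directly onto the
    /// target device `dev`.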
pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
if !src.device().is_cpu() {
            crate::bail!(
                "`quantize_onto` expects `src` to be on the cpu, got {:?}.",
                src.device()
            )
}
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::core::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
)
}
let mut storage = dev.qzeros(elem_count, dtype)?;
storage.quantize_onto(&src.storage())?;
Ok(Self {
storage,
shape: shape.clone(),
})
}
pub fn dtype(&self) -> GgmlDType {
self.storage.dtype()
}
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn rank(&self) -> usize {
self.shape.rank()
}
pub fn shape(&self) -> &Shape {
&self.shape
}
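    /// Dequantizes back to an `f32` tensor on `device`.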
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::core::op::BackpropOp::none();
crate::core::tensor::from_storage(storage, self.shape.clone(), none, false)
.to_device(device)
}
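    /// Dequantizes to an `f16` tensor. CUDA storage uses a dedicated f16 path, the other
    /// backends dequantize to f32 and then cast.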
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::core::op::BackpropOp::none();
crate::core::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device)
}
_ => {
let s = self.dequantize(device)?.to_dtype(crate::core::DType::F16)?;
Ok(s)
}
}
}
pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes()
}
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
self.storage.data()
}
}
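/// A matmul operator over quantized (or eagerly dequantized) weights.
///
/// The weight is stored as `(out_features, in_features)` and applied transposed, so `forward`
/// maps inputs of shape `(.., in_features)` to `(.., out_features)`. A rough usage sketch
/// (illustrative only, import paths depend on the consuming crate):
///
/// ```ignore
/// let qmatmul = QMatMul::from_qtensor(qweight)?;
/// let ys = qmatmul.forward(&xs)?;
/// ```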
#[derive(Clone, Debug)]
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
TensorF16(Tensor),
}
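// Setting `CANDLE_DEQUANTIZE_ALL` (resp. `CANDLE_DEQUANTIZE_ALL_F16`) to a non-empty value
// other than "0" makes `QMatMul::from_arc` eagerly dequantize the weights to f32 (resp. f16)
// instead of keeping them quantized.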
thread_local! {
static DEQUANTIZE_ALL: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
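/// Reinterprets a raw byte buffer as a vector of quantized blocks without copying.
///
/// # Safety
/// `data` must contain exactly `elem_count / T::BLCK_SIZE` valid blocks of `T` and its
/// allocation must be compatible with a `Vec<T>`; note that the byte capacity of `data` is
/// reused as the element capacity of the result.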
unsafe fn transmute_to_blocks<T: GgmlType>(data: Vec<u8>, elem_count: usize) -> Vec<T> {
let raw_ptr = data.as_ptr();
let capacity = data.capacity();
    mem::forget(data);
    unsafe { Vec::from_raw_parts(raw_ptr as *mut T, elem_count / T::BLCK_SIZE, capacity) }
}
impl QMatMul {
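    /// Wraps a shared [`QTensor`]. Float dtypes and the `CANDLE_DEQUANTIZE_ALL*` overrides
    /// are dequantized eagerly; everything else stays quantized.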
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
_ => DEQUANTIZE_ALL.with(|b| *b),
};
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
Ok(t)
}
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
pub fn dequantize_f16(&self) -> Result<Tensor> {
match self {
Self::QTensor(t) => t.dequantize_f16(&t.device()),
Self::Tensor(t) => t.to_dtype(DType::F16),
Self::TensorF16(t) => Ok(t.clone()),
}
}
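    /// Runs the matmul in f16: the weight is dequantized to f16, `xs` is cast to f16 for the
    /// matmul and the result is cast back to the input dtype.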
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
let w = self.dequantize_f16()?;
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
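    /// Moves the weights to `dev`, rebuilding the device-specific quantized storage from the
    /// raw quantized bytes when the device kind changes.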
pub fn to_device(&self, dev: &Device) -> Result<Self> {
match self {
Self::QTensor(q) => {
let storage = &q.storage;
let shape = q.shape.clone();
let data = q.data()?;
let elem_count = shape.elem_count();
let new_s = match (storage, dev) {
(QStorage::Cpu(_), Device::Cpu) => {
return Ok(Self::QTensor(q.clone()));
}
(QStorage::Metal(_), Device::Cpu) | (QStorage::Cuda(_), Device::Cpu) => {
let new_data: Box<dyn QuantizedType> = unsafe {
match q.dtype() {
GgmlDType::F32 => {
Box::new(transmute_to_blocks::<f32>(data.to_vec(), elem_count))
}
GgmlDType::F16 => {
Box::new(transmute_to_blocks::<f16>(data.to_vec(), elem_count))
}
GgmlDType::Q4_0 => Box::new(transmute_to_blocks::<BlockQ4_0>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q4_1 => Box::new(transmute_to_blocks::<BlockQ4_1>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q5_0 => Box::new(transmute_to_blocks::<BlockQ5_0>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q5_1 => Box::new(transmute_to_blocks::<BlockQ5_1>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q8_0 => Box::new(transmute_to_blocks::<BlockQ8_0>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q8_1 => Box::new(transmute_to_blocks::<BlockQ8_1>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q2K => Box::new(transmute_to_blocks::<BlockQ2K>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q3K => Box::new(transmute_to_blocks::<BlockQ3K>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q4K => Box::new(transmute_to_blocks::<BlockQ4K>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q5K => Box::new(transmute_to_blocks::<BlockQ5K>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q6K => Box::new(transmute_to_blocks::<BlockQ6K>(
data.to_vec(),
elem_count,
)),
GgmlDType::Q8K => Box::new(transmute_to_blocks::<BlockQ8K>(
data.to_vec(),
elem_count,
)),
GgmlDType::BF16 => {
Box::new(transmute_to_blocks::<bf16>(data.to_vec(), elem_count))
}
}
};
QStorage::Cpu(new_data)
}
#[cfg(feature = "cuda")]
(QStorage::Cpu(_) | QStorage::Metal(_), Device::Cuda(d)) => {
use crate::core::cuda_backend::WrapErr;
let slice = d.htod_sync_copy(&data).w()?;
QStorage::Cuda(cuda::QCudaStorage::from_buffer(slice, d, q.dtype())?)
}
#[cfg(not(feature = "cuda"))]
(QStorage::Cpu(_) | QStorage::Metal(_), Device::Cuda(_)) => {
return Err(crate::core::Error::NotCompiledWithCudaSupport);
}
#[cfg(feature = "cuda")]
(QStorage::Cuda(_), Device::Cuda(d)) => {
use super::backend::BackendDevice;
use crate::core::cuda_backend::WrapErr;
if d.same_device(q.device().as_cuda_device()?) {
return Ok(Self::QTensor(q.clone()));
}
let slice = d.htod_sync_copy(&data).w()?;
QStorage::Cuda(cuda::QCudaStorage::from_buffer(slice, d, q.dtype())?)
}
#[cfg(not(feature = "cuda"))]
(QStorage::Cuda(_), Device::Cuda(_)) => {
return Err(crate::core::Error::NotCompiledWithCudaSupport);
}
#[cfg(feature = "metal")]
(QStorage::Cpu(_) | QStorage::Cuda(_), Device::Metal(d)) => {
let buffer = d.new_buffer_with_data(&data)?;
QStorage::Metal(metal::QMetalStorage::from_buffer(buffer, d, q.dtype())?)
}
#[cfg(not(feature = "metal"))]
(QStorage::Cpu(_) | QStorage::Cuda(_), Device::Metal(_)) => {
return Err(crate::core::Error::NotCompiledWithMetalSupport);
}
#[cfg(feature = "metal")]
(QStorage::Metal(_), Device::Metal(d)) => {
use super::backend::BackendDevice;
if d.same_device(q.device().as_metal_device()?) {
return Ok(Self::QTensor(q.clone()));
}
let buffer = d.new_buffer_with_data(&data)?;
QStorage::Metal(metal::QMetalStorage::from_buffer(buffer, d, q.dtype())?)
}
#[cfg(not(feature = "metal"))]
(QStorage::Metal(_), Device::Metal(_)) => {
return Err(crate::core::Error::NotCompiledWithMetalSupport);
}
};
Ok(Self::QTensor(Arc::new(QTensor::new(new_s, shape)?)))
}
Self::Tensor(t) => Ok(Self::Tensor(t.to_device(dev)?)),
Self::TensorF16(t) => Ok(Self::TensorF16(t.to_device(dev)?)),
}
}
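    /// The size of the stored weights in bytes, i.e. the quantized size for the `QTensor`
    /// variant.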
pub fn size_in_bytes(&self) -> Result<usize> {
match self {
Self::Tensor(t) | Self::TensorF16(t) => Ok(t.dtype().size_in_bytes() * t.elem_count()),
Self::QTensor(q) => {
let len = match &q.storage {
QStorage::Cpu(_) => q.data()?.len(),
QStorage::Cuda(c) => c.storage_size_in_bytes(),
QStorage::Metal(m) => m.storage_size_in_bytes(),
};
Ok(len)
}
}
}
}
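// The quantized matmul itself: `QTensor` acts as a `CustomOp1` so that
// `xs.apply_op1_no_bwd(&qtensor)` computes `xs @ w^T` with the backend-specific quantized
// kernels and without a backward pass.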
impl crate::core::CustomOp1 for QTensor {
fn name(&self) -> &'static str {
"qmatmul"
}
fn cpu_fwd(
&self,
storage: &crate::core::CpuStorage,
layout: &crate::core::Layout,
) -> Result<(crate::core::CpuStorage, Shape)> {
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
let (n, k) = self.shape.dims2()?;
if src_shape.rank() < 2 {
            crate::bail!("input tensor must have at least two dimensions, got {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
            crate::bail!(
                "input tensor {layout:?} is incompatible with the weight shape {:?}",
                self.shape
            )
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
#[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => {
                crate::bail!("cpu_fwd called on a non-cpu QTensor storage")
            }
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
Ok((crate::core::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::core::MetalStorage,
layout: &crate::core::Layout,
) -> Result<(crate::core::MetalStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Metal(metal) => metal,
            _ => unreachable!("cannot call the metal matmul on a non-metal QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
fn cuda_fwd(
&self,
storage: &crate::core::CudaStorage,
layout: &crate::core::Layout,
) -> Result<(crate::core::CudaStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("cannot call the cuda matmul on a non-cuda QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
}
impl crate::core::Module for QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.matmul(&w)
}
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
}
}