use std::{
borrow::Cow,
fmt::{Debug, Display},
num::NonZeroUsize,
sync::{atomic::AtomicUsize, Arc},
};
use candle_core::{
quantized::{GgmlDType, QTensor},
DType, Device, Result, Tensor,
};
#[cfg(feature = "metal")]
mod metal_kernels;
mod bitsandbytes;
mod cublaslt;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod unquantized;
mod utils;
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
use gptq::gptq_linear;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::ImatrixLayerStats;
pub use unquantized::UnquantLinear;
use candle_nn::{Linear, Module, VarBuilder};
use serde::{Deserialize, Serialize};
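/// The checkpoint quantization scheme declared in a model's
/// `quantization_config` (`"gptq"` or `"bitsandbytes"`); `Unreachable` is a
/// serde placeholder that should never be constructed directly.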
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum QuantMethodType {
#[serde(rename = "gptq")]
Gptq,
#[serde(rename = "unreachable")]
Unreachable,
#[default]
#[serde(rename = "bitsandbytes")]
Bitsandbytes,
}
impl Display for QuantMethodType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Gptq => write!(f, "GPTQ"),
Self::Bitsandbytes => write!(f, "bnb"),
            Self::Unreachable => write!(f, "unreachable"),
}
}
}
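/// Mirror of the `quantization_config` section of a model's `config.json`.
///
/// A minimal parsing sketch (the field values are hypothetical, and
/// `serde_json` is assumed to be available as a dependency):
///
/// ```ignore
/// let raw = r#"{ "bits": 4, "group_size": 128, "quant_method": "gptq" }"#;
/// let cfg: QuantizedConfig = serde_json::from_str(raw).unwrap();
/// assert_eq!(cfg.bits, Some(4));
/// ```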
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct QuantizedConfig {
    /// Bit width of the quantized weights (e.g. 4 or 8), when reported.
    pub bits: Option<usize>,
    /// Number of weight elements sharing one set of quantization parameters.
    pub group_size: Option<usize>,
    /// Checkpoint layout hint, e.g. `"marlin"` for GPTQ checkpoints.
    pub checkpoint_format: Option<String>,
    /// bitsandbytes 4-bit flavor, e.g. `"fp4"` or `"nf4"`.
    pub bnb_4bit_quant_type: Option<String>,
    /// Which quantization scheme produced the checkpoint.
    pub quant_method: QuantMethodType,
}
impl QuantizedConfig {
    /// Human-readable description of the quantization width, e.g. `"4 bits"`,
    /// falling back to the bnb quant type (or `"int8"`) when `bits` is absent.
    pub fn get_bits_name(&self, _vb: &VarBuilder) -> String {
        match self.bits {
            Some(bits) => format!("{bits} bits"),
            // bitsandbytes configs often omit `bits`; report the quant type instead.
            None => self
                .bnb_4bit_quant_type
                .clone()
                .unwrap_or_else(|| "int8".to_string()),
        }
    }
}
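/// Constructor inputs for each [`QuantMethod`] implementation; the variant
/// passed to [`QuantMethod::new`] must match the backend (e.g.
/// `QuantMethodConfig::Gguf` for [`GgufMatMul`]).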
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
Gptq {
bits: i32,
use_exllama: bool,
q_weight: Tensor,
gptq_qzeros: Option<Tensor>,
gptq_scales: Tensor,
g_idx: Option<Tensor>,
bias: Option<Tensor>,
workspace: Option<Tensor>,
is_marlin: bool,
},
Gguf {
q_weight: Arc<QTensor>,
b: Option<Tensor>,
},
Unquantized(Linear),
Hqq {
tensor: Tensor,
bits: HqqBits,
group_size: NonZeroUsize,
axis: HqqAxis,
optimization_steps: Option<usize>,
round_zeros: Option<bool>,
channel_wise: Option<bool>,
bias: Option<Tensor>,
},
Dummy,
FP8 {
lin: Linear,
dtype: DType,
},
Bnb {
weight: Tensor,
bias: Option<Tensor>,
params: BnbQuantParmas,
quant_ty: BnbQuantType,
},
}
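/// Target data types for in-situ quantization (ISQ): GGML block formats
/// (`Q4_0` through `Q8K`), HQQ at 8 or 4 bits, and FP8 (E4M3).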
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
pub enum IsqType {
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
HQQ8,
HQQ4,
F8E4M3,
}
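/// GGML-backed ISQ types map one-to-one to a [`GgmlDType`]; `HQQ8`, `HQQ4`,
/// and `F8E4M3` have no GGML equivalent and fail the conversion. On CUDA,
/// only the subset with kernel support is accepted.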
impl TryFrom<IsqType> for GgmlDType {
type Error = candle_core::Error;
fn try_from(value: IsqType) -> Result<Self> {
let tp = match value {
IsqType::Q2K => Self::Q2K,
IsqType::Q3K => Self::Q3K,
IsqType::Q4K => Self::Q4K,
IsqType::Q4_0 => Self::Q4_0,
IsqType::Q4_1 => Self::Q4_1,
IsqType::Q5K => Self::Q5K,
IsqType::Q5_0 => Self::Q5_0,
IsqType::Q5_1 => Self::Q5_1,
IsqType::Q6K => Self::Q6K,
IsqType::Q8K => Self::Q8K,
IsqType::Q8_0 => Self::Q8_0,
IsqType::Q8_1 => Self::Q8_1,
            ty => candle_core::bail!("ISQ type {ty:?} has no GGML equivalent."),
};
#[cfg(feature = "cuda")]
{
if !matches!(
tp,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
) {
candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
}
}
Ok(tp)
}
}
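/// Numeric tag stored with serialized quantized layers, identifying which
/// format (and thus which [`QuantizedSerde::deserialize`] implementation)
/// produced them.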
pub enum QuantizedSerdeType {
Gguf = 0,
Unquant = 1,
Hqq = 2,
Fp8 = 3,
}
impl TryFrom<usize> for QuantizedSerdeType {
type Error = candle_core::Error;
fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
match value {
0 => Ok(Self::Gguf),
1 => Ok(Self::Unquant),
2 => Ok(Self::Hqq),
3 => Ok(Self::Fp8),
other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
}
}
}
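/// Serialization hooks for quantized layers. The defaults report no support;
/// formats that can round-trip (see [`QuantizedSerdeType`]) override
/// `isq_serde_supported`, `serialize`, and `deserialize`.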
pub trait QuantizedSerde {
fn name(&self) -> &'static str;
fn isq_serde_supported(&self) -> bool {
false
}
fn serialize(&self) -> Result<Cow<[u8]>> {
candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
}
fn deserialize(_data: Cow<[u8]>, _device: &Device) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
}
}
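/// Common interface over quantized and unquantized linear layers (GGUF, GPTQ,
/// HQQ, FP8, bitsandbytes, or plain [`Linear`]), letting model code hold an
/// `Arc<dyn QuantMethod>` regardless of the underlying weight format.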
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
fn new(method: QuantMethodConfig) -> Result<Self>
where
Self: Sized;
fn dequantize_w(&self) -> Result<Tensor>;
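    /// Runs the layer, casting the activation to `quantized_act_type()` first
    /// (when one is set) and casting the output back to the input's dtype.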
fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
let original_ty = a.dtype();
let a = if let Some(t) = self.quantized_act_type() {
a.to_dtype(t)?
} else {
a.clone()
};
self.forward(&a)?.to_dtype(original_ty)
}
    /// Applies the layer to the activation `a`.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;
    /// Variant of [`Self::forward`] that may route through half precision;
    /// the default simply delegates to `forward`.
    fn forward_via_half(&self, a: &Tensor) -> Result<Tensor> {
        self.forward(a)
    }
    /// If `Some`, activations should be cast to this dtype before `forward`.
    fn quantized_act_type(&self) -> Option<DType>;
    fn dtype_and_device(&self) -> (DType, Device);
    /// Returns a copy of this layer with `delta` added to its weight.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;
    /// Requantizes this layer ("in-situ quantization") to `dtype` on `device`,
    /// incrementing `n_quantized` and optionally weighting the quantization
    /// error with an importance matrix (`imatrix_weight`).
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
    ) -> Result<Arc<dyn QuantMethod>>;
    /// Converts this layer to an equivalent GGUF-quantized layer, if applicable.
    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>>;
    fn get_bias_mut(&mut self) -> Option<&mut Tensor>;
    /// Upper bound on CPU threads to use when quantizing to `dtype`, if limited.
    fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option<NonZeroUsize>;
fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
None
}
fn begin_track_stats(&mut self) -> Result<()> {
candle_core::bail!("`{}` does not support tracking stats.", self.name())
}
fn end_track_stats(&self) -> Result<Tensor> {
candle_core::bail!("`{}` does not support tracking stats.", self.name())
}
}
impl Module for dyn QuantMethod {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
Self::forward(self, xs)
}
}
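/// Builds a linear layer without a bias. With a quantized `config` this
/// dispatches to GPTQ or bitsandbytes; otherwise it loads an unquantized
/// layer, or a [`DummyLayer`] if the checkpoint lacks the weight tensor.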
pub fn linear_no_bias(
in_dim: usize,
out_dim: usize,
config: &Option<QuantizedConfig>,
vb: VarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let layer = if let Some(quant_conf) = &config {
match quant_conf.quant_method {
QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
QuantMethodType::Bitsandbytes => {
Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
}
QuantMethodType::Unreachable => unreachable!(),
}
    } else if !vb.contains_tensor("weight") {
        // No weight in the checkpoint: substitute an inert `DummyLayer`.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let layer = candle_nn::linear_no_bias(in_dim, out_dim, vb)?;
        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(layer))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
Ok(layer)
}
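/// Builds a linear layer with a bias. With a quantized `config` this
/// dispatches to GPTQ or bitsandbytes; otherwise it loads an unquantized
/// layer, or a [`DummyLayer`] if the checkpoint lacks the weight or bias.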
pub fn linear(
in_dim: usize,
out_dim: usize,
config: &Option<QuantizedConfig>,
vb: VarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let layer = if let Some(quant_conf) = &config {
match quant_conf.quant_method {
QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
QuantMethodType::Bitsandbytes => {
Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
}
QuantMethodType::Unreachable => unreachable!(),
}
    } else if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
        // Weight or bias missing from the checkpoint: substitute an inert `DummyLayer`.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let layer = candle_nn::linear(in_dim, out_dim, vb)?;
        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(layer))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
Ok(layer)
}
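/// Builds a linear layer with or without a bias, forwarding to [`linear`] or
/// [`linear_no_bias`].
///
/// A hedged usage sketch: `VarBuilder::zeros` stands in for weights loaded
/// from a real checkpoint, the shapes are illustrative only, and the `?`
/// operator assumes an enclosing function returning `candle_core::Result`:
///
/// ```ignore
/// let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
/// let layer = linear_b(768, 768, /* bias */ true, &None, vb)?;
/// let x = Tensor::zeros((1, 768), DType::F32, &Device::Cpu)?;
/// let y = layer.forward_autocast(&x)?; // shape (1, 768)
/// ```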
pub fn linear_b(
in_dim: usize,
out_dim: usize,
bias: bool,
config: &Option<QuantizedConfig>,
vb: VarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
if bias {
linear(in_dim, out_dim, config, vb)
} else {
linear_no_bias(in_dim, out_dim, config, vb)
}
}