mistralrs_quant/unquantized/mod.rs

use std::{
borrow::Cow,
io::Cursor,
num::NonZeroUsize,
sync::{atomic::AtomicUsize, Arc},
};
use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;
use crate::{
cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_HANDLE},
generate_isq, generate_isq_imatrix,
hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
utils::{deserialize_tensor, serialize_tensor, version_is_compatible, HQFF_VERSION},
FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, QuantMethod, QuantMethodConfig,
QuantizedSerde, QuantizedSerdeType,
};
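/// Linear layer stored in full precision: the weight, an optional bias, and
/// optional imatrix statistics collected while calibrating for ISQ.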
#[derive(Debug)]
pub struct UnquantLinear {
w: Tensor,
b: Option<Tensor>,
stats: Option<ImatrixLayerStats>,
}
impl QuantMethod for UnquantLinear {
fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
where
Self: Sized,
{
match method {
QuantMethodConfig::Gguf { .. }
| QuantMethodConfig::Gptq { .. }
| QuantMethodConfig::Hqq { .. }
| QuantMethodConfig::Dummy
| QuantMethodConfig::FP8 { .. }
| QuantMethodConfig::Bnb { .. } => unreachable!(),
QuantMethodConfig::Unquantized(l) => Ok(Self {
w: l.weight().clone(),
b: l.bias().cloned(),
stats: None,
}),
}
}
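    // No quantization is applied, so "dequantizing" just returns the stored weight.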
fn dequantize_w(&self) -> Result<Tensor> {
Ok(self.w.clone())
}
fn forward(&self, a: &Tensor) -> Result<Tensor> {
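        // Ensure the global cuBLASLt handle is initialized before it is queried below.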
maybe_init_cublas_lt_wrapper();
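        // Broadcast the 2-D weight over the batch dims so 3-D/4-D inputs can use
        // a batched matmul.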
let w = match *a.dims() {
[b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
[bsize, _, _] => self.w.broadcast_left(bsize)?,
_ => self.w.clone(),
};
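        // Accumulate activation statistics when imatrix tracking is active.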
if let Some(stats) = &self.stats {
stats.process(a)?;
}
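        // With a bias, seed the output with the broadcast bias and run an
        // accumulating matmul (out = a @ w^T + b) on the backend's fast path.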
if let Some(b) = self.b.as_ref() {
let mut tgt_shape = a.dims().to_vec();
tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;
match a.device().location() {
DeviceLocation::Cuda { .. } => {
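                    // Try cuBLASLt first; otherwise fall back to an accumulating GEMM.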
if let (Device::Cuda(_), Some(cublaslt)) =
(a.device(), *CUBLASLT_HANDLE.lock().unwrap())
{
cublaslt
.batch_matmul(
a,
&w,
Some(&b.t()?.contiguous()?),
None,
Some(1.0),
None,
None,
)?
.t()
} else {
let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
Ok(out)
}
}
                DeviceLocation::Metal { .. } => {
                    // Metal path: run the accumulating matmul in F32 and cast the
                    // result back to the activation dtype.
                    let mut out = b.contiguous()?.to_dtype(DType::F32)?;
                    a.to_dtype(DType::F32)?
                        .matmul_with_alpha_beta(&w.to_dtype(DType::F32)?.t()?, &mut out, None)?;
                    out.to_dtype(a.dtype())
                }
DeviceLocation::Cpu => {
let mut out = b.contiguous()?;
a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
Ok(out)
}
}
} else {
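            // No bias: a plain matmul against the transposed weight suffices.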
a.matmul(&w.t()?)
}
}
fn quantized_act_type(&self) -> Option<DType> {
None
}
fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
Ok(Arc::new(Self {
w: (&self.w + delta)?,
b: self.b.clone(),
stats: self.stats.clone(),
}))
}
fn dtype_and_device(&self) -> (DType, candle_core::Device) {
(self.w.dtype(), self.w.device().clone())
}
fn get_bias_mut(&mut self) -> Option<&mut Tensor> {
None
}
fn apply_isq(
self: Arc<Self>,
dtype: Option<IsqType>,
device: Device,
n_quantized: &AtomicUsize,
imatrix_weight: Option<Vec<f32>>,
) -> Result<Arc<dyn QuantMethod>> {
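        // Quantize this layer in place to the requested ISQ target, placing the
        // result on `device`.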
match dtype {
Some(IsqType::HQQ4 | IsqType::HQQ8) => {
if imatrix_weight.is_some() {
candle_core::bail!("HQQ does not support imatrix.");
}
n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let bits = match dtype.unwrap() {
IsqType::HQQ8 => HqqBits::Eight,
IsqType::HQQ4 => HqqBits::Four,
_ => unreachable!(),
};
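                // Crate-default HQQ settings; only the bit width varies with the ISQ type.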
let cfg = HqqConfig {
bits,
group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
axis: HqqAxis::Zero,
optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
round_zeros: false,
channel_wise: true,
};
let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
if let Some(bias) = &self.b {
let bias = bias
.to_device(&device)?
.to_dtype(res.dtype_and_device().0)?;
Ok(Arc::new(res.with_bias(bias)))
} else {
Ok(Arc::new(res))
}
}
Some(
IsqType::Q2K
| IsqType::Q3K
| IsqType::Q4K
| IsqType::Q4_0
| IsqType::Q4_1
| IsqType::Q5K
| IsqType::Q5_0
| IsqType::Q5_1
| IsqType::Q6K
| IsqType::Q8K
| IsqType::Q8_0
| IsqType::Q8_1,
) => {
let dtype: GgmlDType = dtype.unwrap().try_into()?;
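                // Prefer imatrix-weighted quantization when calibration data is present.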
let res = if let Some(imatrix_weight) = imatrix_weight {
generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized)
} else {
generate_isq!(self.w, device, dtype, n_quantized)
};
                let b = match &self.b {
                    Some(b) => Some(b.to_dtype(DType::F32)?.to_device(&device)?),
                    None => None,
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b,
                })?))
}
Some(IsqType::F8E4M3) => {
if imatrix_weight.is_some() {
candle_core::bail!("F8E4M3 does not support imatrix.");
}
let w = self.w.to_device(&device)?;
let b = if let Some(b) = &self.b {
Some(b.to_device(&device)?)
} else {
None
};
Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
lin: Linear::new(w, b),
dtype: DType::F8E4M3,
})?))
}
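            // No ISQ target: keep the weights unquantized, just moved to `device`.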
None => {
let w = self.w.to_device(&device)?;
let b = if let Some(b) = &self.b {
Some(b.to_device(&device)?)
} else {
None
};
Ok(Arc::new(UnquantLinear::new(
QuantMethodConfig::Unquantized(Linear::new(w, b)),
)?))
}
}
}
fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option<NonZeroUsize> {
match dtype {
IsqType::HQQ4 | IsqType::HQQ8 => {
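                // Use a single thread because HQQ quantizes on the GPU.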
Some(1.try_into().unwrap())
}
IsqType::F8E4M3 => None,
IsqType::Q2K
| IsqType::Q3K
| IsqType::Q4K
| IsqType::Q4_0
| IsqType::Q4_1
| IsqType::Q5K
| IsqType::Q5_0
| IsqType::Q5_1
| IsqType::Q6K
| IsqType::Q8K
| IsqType::Q8_0
| IsqType::Q8_1 => None,
}
}
fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
Some((self.w.clone(), self.b.clone()))
}
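    // Begin imatrix calibration: track activation statistics for this layer.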
fn begin_track_stats(&mut self) -> Result<()> {
self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
Ok(())
}
fn end_track_stats(&self) -> Result<Tensor> {
if let Some(stats) = &self.stats {
let imatrix = stats.compute_imatrix()?;
stats.clear()?;
Ok(imatrix)
} else {
candle_core::bail!("`{}` does not support tracking stats.", self.name())
}
}
fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
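        // Nothing to convert; unquantized weights are returned as-is.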
Ok(self.clone())
}
}
impl QuantizedSerde for UnquantLinear {
fn isq_serde_supported(&self) -> bool {
true
}
fn name(&self) -> &'static str {
"unquant-linear"
}
fn serialize(&self) -> Result<Cow<[u8]>> {
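        // HQFF layout: [u32 LE version][u8 serde type][u8 has-bias flag][weight][bias?]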
let mut buffer = Vec::new();
buffer.extend(&HQFF_VERSION.to_le_bytes());
buffer.push(QuantizedSerdeType::Unquant as u8);
buffer.push(self.b.is_some() as u8);
serialize_tensor(&mut buffer, &self.w)?;
if let Some(bias) = &self.b {
serialize_tensor(&mut buffer, bias)?;
}
Ok(Cow::from(buffer))
}
fn deserialize(data: Cow<[u8]>, device: &Device) -> Result<Arc<dyn QuantMethod>>
where
Self: Sized,
{
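        // Mirror of `serialize`: version, type tag, bias flag, then the tensor(s).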
let mut buffer = Cursor::new(data.to_vec());
let version = buffer.read_u32::<LittleEndian>()?;
if let Err(e) = version_is_compatible(version) {
return Err(candle_core::Error::wrap(e));
}
let isq_type = buffer.read_u8()? as usize;
if isq_type != QuantizedSerdeType::Unquant as usize {
candle_core::bail!(
"ISQ type ({isq_type}) doesn't match expected type {}",
QuantizedSerdeType::Unquant as usize
);
}
let has_bias = buffer.read_u8()? != 0;
let w = deserialize_tensor(&mut buffer, device)?;
let b = if has_bias {
Some(deserialize_tensor(&mut buffer, device)?)
} else {
None
};
Ok(Arc::new(Self { w, b, stats: None }))
}
}