mistralrs_quant/hqq/
quantize.rs1use candle_core::{DType, Device, Result, Tensor};
2
3use crate::hqq::optimize::OptResults;
4
5use super::{optimize::OptParams, HqqAxis, HqqConfig, HqqLayer};
6
7impl HqqLayer {
8 pub fn quantize(input: &Tensor, device: &Device, cfg: HqqConfig) -> Result<Self> {
10 let group_size: usize = cfg.group_size.into();
11 if input.elem_count() % group_size != 0 {
12 candle_core::bail!("`group_size` should be divisible by the tensor number of elements, which are {}, got a group size of {group_size}.", input.elem_count());
13 }
14
15 let mut w = input.clone().to_dtype(DType::F32)?;
16
17 w = if cfg.channel_wise {
19 match cfg.axis {
20 HqqAxis::One => w.reshape(((), group_size))?,
21 HqqAxis::Zero => w.reshape((group_size, ()))?,
22 }
23 } else {
24 w
25 };
26
27 let (min, max) = if !cfg.channel_wise {
29 let mut min = w.min(0)?;
31 let mut max = w.max(0)?;
32 while !min.dims().is_empty() {
33 min = min.min(0)?;
34 max = max.max(0)?;
35 }
36 (min, max)
37 } else {
38 (
39 w.min_keepdim(cfg.axis as usize)?,
40 w.max_keepdim(cfg.axis as usize)?,
41 )
42 };
43
44 let max_v = (2f64.powf(cfg.bits as usize as f64) - 1.).round();
45
46 let scale = (max_v / (max - &min)?)?.clamp(0., 2e4)?;
49 let mut zero = (min.neg()? * &scale)?;
50
51 if cfg.round_zeros {
52 zero = zero.round()?;
53 }
54
55 let OptResults { wq, scale, zero } = Self::optimize_weights_proximal_legacy(
64 &w,
65 &scale,
66 zero,
67 0.,
68 max_v,
69 cfg.axis,
70 OptParams::default(cfg.optimization_steps),
71 )?;
72
73 let quant_w = cfg.bits.bitpack_type()(wq)?.to_device(device)?;
74
75 let this = Self {
76 w_q: quant_w,
77 zeros: zero.to_device(device)?,
78 scales: (1.0 / scale)?.to_device(device)?,
79 bias: None,
80 w_shape: input.shape().clone(),
81 cfg,
82 };
83 Ok(this)
84 }
85}
86
#[cfg(test)]
mod test {
    use candle_core::{Device, Result, Tensor};

    /// Smoke test: quantizing a random 10x10 `f32` tensor with 3-bit, channel-wise
    /// settings must complete without returning an error.
    #[test]
    fn test_quantize_hqq() -> Result<()> {
        use candle_core::DType;

        use crate::{HqqAxis, HqqBits, HqqConfig, HqqLayer};

        // Use Metal when that feature is enabled; otherwise CUDA if available.
        #[cfg(feature = "metal")]
        let device = Device::new_metal(0)?;
        #[cfg(not(feature = "metal"))]
        let device = Device::cuda_if_available(0)?;

        let weights = Tensor::rand(0f32, 1f32, (10, 10), &device)?.to_dtype(DType::F32)?;

        let config = HqqConfig {
            bits: HqqBits::Three,
            group_size: 10.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: None,
            round_zeros: false,
            channel_wise: true,
        };

        let _layer = HqqLayer::quantize(&weights, &device, config)?;

        Ok(())
    }
}