mistralrs_quant/hqq/quantize.rs

use candle_core::{DType, Device, Result, Tensor};

use crate::hqq::optimize::OptResults;

use super::{optimize::OptParams, HqqAxis, HqqConfig, HqqLayer};

impl HqqLayer {
    /// Quantize the input tensor into an HQQ layer.
    pub fn quantize(input: &Tensor, device: &Device, cfg: HqqConfig) -> Result<Self> {
        let group_size: usize = cfg.group_size.into();
        if input.elem_count() % group_size != 0 {
            candle_core::bail!("The number of elements in the tensor ({}) must be divisible by `group_size`; got a group size of {group_size}.", input.elem_count());
        }
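        // E.g. a 10 x 10 tensor (100 elements) works with group_size = 10 but would be
        // rejected with group_size = 8 (100 % 8 != 0).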

        let mut w = input.clone().to_dtype(DType::F32)?;

        // Reshape for grouping
        w = if cfg.channel_wise {
            match cfg.axis {
                HqqAxis::One => w.reshape(((), group_size))?,
                HqqAxis::Zero => w.reshape((group_size, ()))?,
            }
        } else {
            w
        };
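        // After this reshape, each row (axis One) or column (axis Zero) holds one
        // quantization group of `group_size` elements with its own scale and zero-point.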

        // Get min and max values
        let (min, max) = if !cfg.channel_wise {
            // TODO we need min_all
            let mut min = w.min(0)?;
            let mut max = w.max(0)?;
            while !min.dims().is_empty() {
                min = min.min(0)?;
                max = max.max(0)?;
            }
            (min, max)
        } else {
            (
                w.min_keepdim(cfg.axis as usize)?,
                w.max_keepdim(cfg.axis as usize)?,
            )
        };

        let max_v = (2f64.powf(cfg.bits as usize as f64) - 1.).round();
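        // max_v is the largest representable quantized value, e.g. 2^3 - 1 = 7 for 3 bits,
        // so quantized weights fall in the range [0, max_v].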

        // Note: we work with the inverse of the scale to avoid division; quantize via W * scale + zero.
        // The scale is inverted again before it is stored.
        // Clamp to avoid half precision problems
        let scale = (max_v / (max - &min)?)?.clamp(0., 2e4)?;
        let mut zero = (min.neg()? * &scale)?;

        if cfg.round_zeros {
            zero = zero.round()?;
        }
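        // Affine quantization: w_q = clamp(round(w * scale + zero), 0, max_v) maps each group's
        // [min, max] range onto [0, max_v]. Since the inverse scale is what gets stored below,
        // dequantization can then be a multiply: w ≈ (w_q - zero) * (1 / scale).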

        // We only support using optimization!
        // let (wq, scale, zero) = (
        //         w.broadcast_mul(&scale)?
        //             .broadcast_add(&zero)?
        //             .clamp(0., max_v)?,
        //         scale,
        //         zero,
        //     );
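        // The proximal (half-quadratic) solver iteratively refines the zero-point and
        // re-quantizes the weights to reduce the reconstruction error; see hqq/optimize.rs.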
        let OptResults { wq, scale, zero } = Self::optimize_weights_proximal_legacy(
            &w,
            &scale,
            zero,
            0.,
            max_v,
            cfg.axis,
            OptParams::default(cfg.optimization_steps),
        )?;

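        // `bitpack_type()` returns the packing function for the configured bit width, which
        // packs the quantized values into the dense `w_q` storage tensor.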
        let quant_w = cfg.bits.bitpack_type()(wq)?.to_device(device)?;

        let this = Self {
            w_q: quant_w,
            zeros: zero.to_device(device)?,
            scales: (1.0 / scale)?.to_device(device)?,
            bias: None,
            w_shape: input.shape().clone(),
            cfg,
        };
        Ok(this)
    }
}

#[cfg(test)]
mod test {
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn test_quantize_hqq() -> Result<()> {
        use candle_core::DType;

        use crate::{HqqAxis, HqqBits, HqqConfig, HqqLayer};

        #[cfg(not(feature = "metal"))]
        let dev = Device::cuda_if_available(0)?;
        #[cfg(feature = "metal")]
        let dev = Device::new_metal(0)?;

        let data = Tensor::rand(0f32, 1f32, (10, 10), &dev)?.to_dtype(DType::F32)?;
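        // 3-bit quantization of a 10 x 10 tensor: with group_size = 10 and axis Zero,
        // each of the ten columns forms one quantization group.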
        let _hqq = HqqLayer::quantize(
            &data,
            &dev,
            HqqConfig {
                bits: HqqBits::Three,
                group_size: 10.try_into()?,
                axis: HqqAxis::Zero,
                optimization_steps: None,
                round_zeros: false,
                channel_wise: true,
            },
        )?;

        // let dequant = hqq.dequantize()?;
        // println!("Initial:\n{data}");
        // println!("Dequantized:\n{dequant}");
        // println!("Difference:\n{}", (&dequant - &data)?.abs()?);

        // dbg!(&(&dequant - &data)?.abs()?.mean_all()?);
        Ok(())
    }
}