mistralrs_quant/hqq/
optimize.rs

1use candle_core::{DType, Result, Tensor};
2
3use super::{HqqAxis, HqqLayer, OPTIMIZER_HQQ_DEFAULT_STEPS};
4
5pub(crate) struct OptParams {
6    pub(crate) lp_norm: f64,
7    pub(crate) beta: f64,
8    pub(crate) kappa: f64,
9    pub(crate) iters: usize,
10}
11
12impl OptParams {
13    pub(crate) fn default(optimization_steps: Option<usize>) -> Self {
14        Self {
15            lp_norm: 0.7,
16            beta: 1e1,
17            kappa: 1.01,
18            iters: optimization_steps.unwrap_or(OPTIMIZER_HQQ_DEFAULT_STEPS),
19        }
20    }
21}
22
23pub(crate) struct OptResults {
24    pub(crate) wq: Tensor,
25    pub(crate) scale: Tensor,
26    pub(crate) zero: Tensor,
27}
28
29fn shrink_lp_op(x: &Tensor, beta: f64, lp_norm: f64) -> Result<Tensor> {
30    if lp_norm == 1. {
31        x.sign()?.broadcast_mul(&(x.abs()? - 1. / beta)?.relu()?)
32    } else {
33        let pow_exp = Tensor::new(lp_norm as f32 - 1., x.device())?
34            .broadcast_as(x.shape().clone())?
35            .to_dtype(x.dtype())?;
36        x.sign()?
37            .broadcast_mul(&(x.abs()? - ((1. / beta) * x.abs()?.pow(&pow_exp)?))?.relu()?)
38    }
39}
40
41impl HqqLayer {
42    // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/optimize.py#L194
43    pub(crate) fn optimize_weights_proximal_legacy(
44        tensor: &Tensor,
45        scale: &Tensor,
46        zero: Tensor,
47        min: f64,
48        max: f64,
49        axis: HqqAxis,
50        opt_params: OptParams,
51    ) -> Result<OptResults> {
52        let OptParams {
53            lp_norm,
54            mut beta,
55            kappa,
56            iters,
57        } = opt_params;
58
59        let wf = tensor.clone();
60        let scale = scale.to_dtype(wf.dtype())?;
61        let mut zero = zero.to_dtype(wf.dtype())?;
62
63        let mut best_error = 1e4;
64        for _ in 0..iters {
65            let wq = wf
66                .broadcast_mul(&scale)?
67                .broadcast_add(&zero)?
68                .round()?
69                .clamp(min, max)?;
70            let wr = wq.broadcast_sub(&zero)?.broadcast_div(&scale)?;
71            let we = shrink_lp_op(&(&wf - &wr)?, beta, lp_norm)?;
72
73            zero = (wq - (&wf - we)?.broadcast_mul(&scale)?)?.mean_keepdim(axis as usize)?;
74            beta *= kappa;
75
76            let current_error = (&wf - wr)?
77                .abs()?
78                .mean_all()?
79                .to_dtype(DType::F32)?
80                .to_scalar::<f32>()?;
81            if current_error < best_error {
82                best_error = current_error;
83            } else {
84                break;
85            }
86        }
87
88        let wq = tensor
89            .broadcast_mul(&scale)?
90            .broadcast_add(&zero)?
91            .round()?
92            .clamp(min, max)?;
93        Ok(OptResults { wq, scale, zero })
94    }
95}