// mistralrs_quant/hqq/optimize.rs
use candle_core::{DType, Result, Tensor};
2
3use super::{HqqAxis, HqqLayer, OPTIMIZER_HQQ_DEFAULT_STEPS};
4
/// Hyper-parameters for the proximal HQQ weight-optimization loop.
pub(crate) struct OptParams {
    /// p of the l_p "norm" used by the shrinkage operator (p < 1 gives a
    /// sparsity-promoting, non-convex penalty).
    pub(crate) lp_norm: f64,
    /// Inverse shrinkage strength: the soft-threshold is scaled by 1/beta.
    pub(crate) beta: f64,
    /// Multiplicative annealing factor applied to `beta` each iteration.
    pub(crate) kappa: f64,
    /// Maximum number of optimizer iterations (the loop may stop earlier
    /// once the reconstruction error stops improving).
    pub(crate) iters: usize,
}
11
12impl OptParams {
13 pub(crate) fn default(optimization_steps: Option<usize>) -> Self {
14 Self {
15 lp_norm: 0.7,
16 beta: 1e1,
17 kappa: 1.01,
18 iters: optimization_steps.unwrap_or(OPTIMIZER_HQQ_DEFAULT_STEPS),
19 }
20 }
21}
22
/// Output of the proximal optimizer: the final quantized weights together
/// with the (dtype-converted) scale and the refined zero-point.
pub(crate) struct OptResults {
    /// Quantized weight tensor, rounded and clamped to the target range.
    pub(crate) wq: Tensor,
    /// Scale tensor, cast to the weight dtype (otherwise unchanged).
    pub(crate) scale: Tensor,
    /// Zero-point tensor after iterative refinement.
    pub(crate) zero: Tensor,
}
28
29fn shrink_lp_op(x: &Tensor, beta: f64, lp_norm: f64) -> Result<Tensor> {
30 if lp_norm == 1. {
31 x.sign()?.broadcast_mul(&(x.abs()? - 1. / beta)?.relu()?)
32 } else {
33 let pow_exp = Tensor::new(lp_norm as f32 - 1., x.device())?
34 .broadcast_as(x.shape().clone())?
35 .to_dtype(x.dtype())?;
36 x.sign()?
37 .broadcast_mul(&(x.abs()? - ((1. / beta) * x.abs()?.pow(&pow_exp)?))?.relu()?)
38 }
39}
40
impl HqqLayer {
    /// Legacy proximal solver that refines the quantization zero-point.
    ///
    /// Alternates between (a) quantizing `tensor` with the current
    /// `scale`/`zero` (round + clamp into `[min, max]`), (b) shrinking the
    /// dequantization residual with `shrink_lp_op`, and (c) re-estimating
    /// `zero` in closed form as a mean along `axis`. `beta` is annealed by
    /// `kappa` each step, and the loop stops early as soon as the mean
    /// absolute reconstruction error stops improving.
    ///
    /// Note the convention here: quantization is `round(w * scale + zero)`
    /// and dequantization is `(wq - zero) / scale`, i.e. `scale` is the
    /// reciprocal of the dequant step size.
    ///
    /// Returns the final quantized weights alongside the dtype-converted
    /// `scale` and refined `zero`.
    pub(crate) fn optimize_weights_proximal_legacy(
        tensor: &Tensor,
        scale: &Tensor,
        zero: Tensor,
        min: f64,
        max: f64,
        axis: HqqAxis,
        opt_params: OptParams,
    ) -> Result<OptResults> {
        let OptParams {
            lp_norm,
            mut beta,
            kappa,
            iters,
        } = opt_params;

        let wf = tensor.clone();
        // Work in the weight dtype; scale/zero presumably broadcast against
        // wf along `axis` — TODO confirm shapes at the call site.
        let scale = scale.to_dtype(wf.dtype())?;
        let mut zero = zero.to_dtype(wf.dtype())?;

        // Large sentinel so the first iteration always registers as an
        // improvement.
        let mut best_error = 1e4;
        for _ in 0..iters {
            // Quantize with the current zero-point.
            let wq = wf
                .broadcast_mul(&scale)?
                .broadcast_add(&zero)?
                .round()?
                .clamp(min, max)?;
            // Dequantized reconstruction.
            let wr = wq.broadcast_sub(&zero)?.broadcast_div(&scale)?;
            // Shrunken residual: the part of the error the l_p penalty keeps.
            let we = shrink_lp_op(&(&wf - &wr)?, beta, lp_norm)?;

            // Closed-form zero-point update (mean over the quantization
            // axis), using wq from *before* this update.
            zero = (wq - (&wf - we)?.broadcast_mul(&scale)?)?.mean_keepdim(axis as usize)?;
            // Anneal the shrinkage strength.
            beta *= kappa;

            // Error is measured against the reconstruction computed with the
            // *previous* zero (wr above), matching the reference HQQ solver.
            let current_error = (&wf - wr)?
                .abs()?
                .mean_all()?
                .to_dtype(DType::F32)?
                .to_scalar::<f32>()?;
            if current_error < best_error {
                best_error = current_error;
            } else {
                // Early stop: error plateaued or got worse; keep the last
                // accepted zero-point.
                break;
            }
        }

        // Final quantization pass with the refined zero-point.
        let wq = tensor
            .broadcast_mul(&scale)?
            .broadcast_add(&zero)?
            .round()?
            .clamp(min, max)?;
        Ok(OptResults { wq, scale, zero })
    }
}