mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};

/// CPU placeholder for the GPTQ quantized linear layer. GPTQ kernels are
/// CUDA-only, so constructing this type always returns an error.
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gptq { .. } => candle_core::bail!("GPTQ is only supported on CUDA."),
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
        }
    }

    // `new` never returns an instance on this backend, so none of the
    // methods below can actually be reached.
    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}

/// Number of quantized values packed into one 32-bit word.
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}
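
// A minimal sketch (not part of the original file) of the packing arithmetic
// the macro encodes: one 32-bit word holds `32 / bits` quantized values, so
// 4-bit weights pack eight per word and 8-bit weights pack four.
#[cfg(test)]
mod pack_factor_tests {
    #[test]
    fn packs_expected_counts() {
        assert_eq!(pack_factor!(4), 8);
        assert_eq!(pack_factor!(8), 4);
        assert_eq!(pack_factor!(2), 16);
    }
}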

pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Gptq {
        bits,
        group_size,
        checkpoint_format: _,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    // Handle the case where the layer is dummy (no tensors)
    if !(vb.contains_tensor("qweight")
        && vb.contains_tensor("qzeros")
        && vb.contains_tensor("g_idx")
        && vb.contains_tensor("scales"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    // Packed quantized weights: `pack_factor!(bits)` values per 32-bit word
    // along the input dimension.
    let qweight = vb.get_with_hints_dtype(
        (in_dim / pack_factor!(bits), out_dim),
        "qweight",
        Default::default(),
        DType::I32,
    )?;
    // One scale and one (packed) zero point per quantization group.
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
    // Maps each input channel to its quantization group.
    let g_idx = vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?;
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

    let config = QuantMethodConfig::Gptq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        gptq_qzeros: Some(qzeros),
        gptq_scales: scales,
        g_idx: Some(g_idx),
        bias,
        workspace: None,
        is_marlin: false,
    };
    // On this CPU backend, `GptqLayer::new` always errors; see above.
    Ok(Arc::new(GptqLayer::new(config)?))
}
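
// Worked example of the shape arithmetic above, with hypothetical dimensions
// (LLaMA-style, not taken from this crate): a 4-bit layer with in_dim = 4096,
// out_dim = 11008, group_size = 128 has pack_factor!(4) == 8, so the loader
// expects:
//   qweight: (4096 / 8,   11008)     = (512, 11008)  I32
//   qzeros:  (4096 / 128, 11008 / 8) = (32, 1376)    I32
//   g_idx:   (4096,)                                  I32
//   scales:  (4096 / 128, 11008)     = (32, 11008)   F16
//   bias:    (11008,)                                 F16 (optional)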