mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};

/// CPU stub for the GPTQ quantized linear layer. The GPTQ matmul kernels
/// are CUDA-only, so constructing this layer always fails.
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::GptqAwq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
            // Every other config is dispatched to its own QuantMethod
            // implementation, so it can never reach this constructor.
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
        }
    }

    // `new` always bails on this backend, so no `GptqLayer` instance can
    // exist and none of the methods below are ever reached.
    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}

/// Number of quantized values that fit in one packed 32-bit word.
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}

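// A quick sanity check of the packing arithmetic above: with 4-bit
// quantization, 32 / 4 = 8 values share one i32 word. A minimal sketch
// (the test module and names are illustrative, not part of the crate):
#[cfg(test)]
mod pack_factor_sketch {
    #[test]
    fn values_per_packed_word() {
        assert_eq!(pack_factor!(4), 8); // eight 4-bit values per i32
        assert_eq!(pack_factor!(8), 4); // four 8-bit values per i32
    }
}
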
pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::GptqAwq {
        bits,
        group_size,
        checkpoint_format: _,
        is_awq,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    let is_awq = *is_awq;
    // Handle the case where the layer is actually unquantized (a plain
    // `weight` tensor is present).
    if vb.contains_tensor("weight") {
        return crate::linear_b(in_dim, out_dim, false, &None, vb);
    }

    // Handle the case where the layer is a dummy (its tensors are missing)
    if !vb.contains_tensor("qweight")
        || !vb.contains_tensor("qzeros")
        || !vb.contains_tensor("scales")
        || (!is_awq && !vb.contains_tensor("g_idx"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let qw_shape = if !is_awq {
        // Quantized GPTQ layout packs along the input dim: (k / pack_factor, n)
        (in_dim / pack_factor!(bits), out_dim)
    } else {
        // Quantized AWQ layout packs along the output dim: (k, n / pack_factor)
        (in_dim, out_dim / pack_factor!(bits))
    };

    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
    // One scale and one (packed) zero point per group of `group_size` input rows.
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
    // AWQ checkpoints do not include `g_idx` (GPTQ's per-channel group indices).
    let g_idx = if is_awq {
        None
    } else {
        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
    };
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

    let config = QuantMethodConfig::GptqAwq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        qzeros: Some(qzeros),
        scales,
        g_idx,
        bias,
        workspace: None,
        is_marlin: false,
        is_awq,
    };
    Ok(Arc::new(GptqLayer::new(config)?))
}
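
// A worked example of the shape arithmetic in `gptq_linear`, as a sketch.
// The concrete numbers (4-bit weights, in_dim = 4096, out_dim = 11008,
// group_size = 128) are chosen for illustration only; this test module is
// not part of the crate.
#[cfg(test)]
mod gptq_shape_sketch {
    #[test]
    fn gptq_4bit_tensor_shapes() {
        let (in_dim, out_dim, bits, group_size) = (4096usize, 11008usize, 4usize, 128usize);
        // qweight packs 8 input rows per i32 word in GPTQ layout.
        assert_eq!((in_dim / pack_factor!(bits), out_dim), (512, 11008));
        // One row of scales/zeros per quantization group.
        let scale_and_zero_size = in_dim / group_size;
        assert_eq!((scale_and_zero_size, out_dim / pack_factor!(bits)), (32, 1376)); // qzeros
        assert_eq!((scale_and_zero_size, out_dim), (32, 11008)); // scales
    }
}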