mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};

/// Stub GPTQ layer for non-CUDA builds. The GPTQ kernels are CUDA-only, so
/// construction always fails with an error on this backend.
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::GptqAwq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => {
                unreachable!()
            }
        }
    }

    // The remaining trait methods are unreachable in practice: `new` always
    // bails on this backend, so no `GptqLayer` is ever constructed here.
    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}

/// Number of quantized values packed into one 32-bit storage word,
/// e.g. `pack_factor!(4) == 8`.
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}

/// Load a GPTQ- or AWQ-quantized linear layer from `vb`. Falls back to an
/// unquantized linear layer if a plain `weight` tensor is present, or to a
/// dummy layer if the expected quantized tensors are missing.
pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::GptqAwq {
        bits,
        group_size,
        checkpoint_format: _,
        is_awq,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    let is_awq = *is_awq;
    // Handle the case where the layer is actually unquantized (a plain `weight` tensor is present).
    if vb.contains_tensor("weight") {
        return crate::linear_b(in_dim, out_dim, false, &None, vb);
    }

    // Handle the case where the layer is dummy (no tensors).
    if !vb.contains_tensor("qweight")
        || !vb.contains_tensor("qzeros")
        || !vb.contains_tensor("scales")
        || (!is_awq && !vb.contains_tensor("g_idx"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let qw_shape = if !is_awq {
        // Quantized GPTQ layout: (k / pack_factor, n)
        (in_dim / pack_factor!(bits), out_dim)
    } else {
        // Quantized AWQ layout: (k, n / pack_factor)
        (in_dim, out_dim / pack_factor!(bits))
    };

    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
    let g_idx = if is_awq {
        None
    } else {
        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
    };
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

    let config = QuantMethodConfig::GptqAwq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        qzeros: Some(qzeros),
        scales,
        g_idx,
        bias,
        workspace: None,
        is_marlin: false,
        is_awq,
    };
    Ok(Arc::new(GptqLayer::new(config)?))
}
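
// A minimal sketch exercising the packing arithmetic above. The 4096/11008
// dimensions are arbitrary example values chosen for illustration, not taken
// from this crate or any particular checkpoint.
#[cfg(test)]
mod tests {
    #[test]
    fn pack_factor_counts_values_per_u32() {
        // A 32-bit storage word holds 8 4-bit values or 4 8-bit values.
        assert_eq!(pack_factor!(4), 8);
        assert_eq!(pack_factor!(8), 4);
    }

    #[test]
    fn qweight_shapes_for_4_bit() {
        let (in_dim, out_dim, bits) = (4096usize, 11008usize, 4usize);
        // GPTQ packs along the input (k) dimension: (k / pack_factor, n).
        assert_eq!((in_dim / pack_factor!(bits), out_dim), (512, 11008));
        // AWQ packs along the output (n) dimension: (k, n / pack_factor).
        assert_eq!((in_dim, out_dim / pack_factor!(bits)), (4096, 1376));
    }
}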