mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};

/// CPU stand-in for the GPTQ quantized linear layer. Constructing it always
/// fails, since GPTQ is only supported on CUDA.
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::GptqAwq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
            // Every other config variant is constructed through its own layer
            // type, so it can never reach GptqLayer::new.
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => {
                unreachable!()
            }
        }
    }

    // The methods below are never reached: `new` always bails, so no
    // `GptqLayer` value can exist on CPU.
    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}
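
/// Number of quantized values packed into one 32-bit storage word, e.g.
/// `32 / 4 = 8` values per `i32` for 4-bit weights.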
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}

pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::GptqAwq {
        bits,
        group_size,
        checkpoint_format: _,
        is_awq,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    let is_awq = *is_awq;
    // Handle the case where we actually have an unquantized layer.
    if vb.contains_tensor("weight") {
        return crate::linear_b(in_dim, out_dim, false, &None, vb);
    }

    // Handle the case where the layer is a dummy (no tensors present). Note
    // that g_idx is required only for GPTQ, not AWQ.
    if !vb.contains_tensor("qweight")
        || !vb.contains_tensor("qzeros")
        || !vb.contains_tensor("scales")
        || (!is_awq && !vb.contains_tensor("g_idx"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let qw_shape = if !is_awq {
        // Quantized GPTQ packs along k: (k / pack_factor, n).
        (in_dim / pack_factor!(bits), out_dim)
    } else {
        // Quantized AWQ packs along n: (k, n / pack_factor).
        (in_dim, out_dim / pack_factor!(bits))
    };
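    // Illustrative numbers (an assumption, not from any particular model):
    // with in_dim = 4096, out_dim = 11008, and 4-bit weights (pack factor 8),
    // GPTQ stores qweight as (512, 11008) while AWQ stores it as (4096, 1376).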

    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
    // One scales/qzeros row per quantization group along the input dim.
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
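    // Like qweight in the AWQ layout, qzeros is bit-packed along the output
    // dimension, hence the row width of out_dim / pack_factor.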
    let g_idx = if is_awq {
        None
    } else {
        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
    };
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

    let config = QuantMethodConfig::GptqAwq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        qzeros: Some(qzeros),
        scales,
        g_idx,
        bias,
        workspace: None,
        is_marlin: false,
        is_awq,
    };
    // On this CPU backend GptqLayer::new always errors, so reaching this point
    // (a real GPTQ/AWQ checkpoint without CUDA) surfaces the "only supported
    // on CUDA" error to the caller.
    Ok(Arc::new(GptqLayer::new(config)?))
}
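
// A minimal sanity-check sketch, not part of the original file: it exercises
// the pack_factor! macro and the shape arithmetic used above. The dimensions
// (4096 x 11008, 4-bit, group size 128) are illustrative assumptions.
#[cfg(test)]
mod tests {
    #[test]
    fn pack_factor_and_shapes() {
        let bits = 4usize;
        let (in_dim, out_dim, group_size) = (4096usize, 11008usize, 128usize);

        // A 32-bit word holds 32 / bits values: 8 for 4-bit weights.
        assert_eq!(pack_factor!(bits), 8);

        // GPTQ packs along k; AWQ packs along n.
        assert_eq!((in_dim / pack_factor!(bits), out_dim), (512, 11008));
        assert_eq!((in_dim, out_dim / pack_factor!(bits)), (4096, 1376));

        // One scales/qzeros row per quantization group along k.
        assert_eq!(in_dim / group_size, 32);
    }
}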