mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};

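/// CPU stand-in for the GPTQ/AWQ quantized layer. The GPTQ matmul kernels
/// exist only for CUDA, so constructing this layer always returns an error.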
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::GptqAwq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => {
                unreachable!()
            }
        }
    }

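    // The methods below are never reached: `new` always errors (or is
    // unreachable), so no `GptqLayer` instance ever exists on this backend.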
    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}

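// Number of `$bits`-wide values packed into one 32-bit word:
// pack_factor!(4) == 8, pack_factor!(8) == 4.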
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}

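/// Loads a GPTQ/AWQ-quantized linear layer from `vb`, falling back to an
/// unquantized layer or a dummy layer when the expected tensors are absent.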
pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::GptqAwq {
        bits,
        group_size,
        checkpoint_format: _,
        is_awq,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    let is_awq = *is_awq;
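    // An unquantized `weight` tensor means this layer was not quantized;
    // fall back to a plain linear layer.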
    if vb.contains_tensor("weight") {
        return crate::linear_b(in_dim, out_dim, false, &None, vb);
    }

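    // If any of the expected quantized tensors are missing, substitute a
    // dummy layer. `g_idx` is only expected for GPTQ checkpoints, not AWQ.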
    if !vb.contains_tensor("qweight")
        || !vb.contains_tensor("qzeros")
        || !vb.contains_tensor("scales")
        || (!is_awq && !vb.contains_tensor("g_idx"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

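    // GPTQ packs the quantized weight along the input dimension; AWQ packs it
    // along the output dimension.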
    let qw_shape = if !is_awq {
        (in_dim / pack_factor!(bits), out_dim)
    } else {
        (in_dim, out_dim / pack_factor!(bits))
    };

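    // One scale and zero point is stored per `group_size` input rows; for
    // GPTQ, `g_idx` gives the group index of each input row.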
    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
    let g_idx = if is_awq {
        None
    } else {
        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
    };
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

    let config = QuantMethodConfig::GptqAwq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        qzeros: Some(qzeros),
        scales,
        g_idx,
        bias,
        workspace: None,
        is_marlin: false,
        is_awq,
    };
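    // On this backend, constructing the layer always surfaces the
    // "GPTQ is only supported on CUDA." error from `GptqLayer::new`.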
    Ok(Arc::new(GptqLayer::new(config)?))
}