mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};

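/// CPU stand-in for the GPTQ quantized linear layer. The real kernels are
/// CUDA-only, so constructing this layer always fails and its trait methods
/// are stubs.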
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::GptqAwq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => {
                unreachable!()
            }
        }
    }

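    // `new` never returns a layer on this backend (it bails or is
    // unreachable for every config), so the stubs below cannot be reached
    // at runtime.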
    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}

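// Number of quantized values packed into one 32-bit word, e.g. 8 for 4-bit
// weights.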
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}

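/// Loads a GPTQ- or AWQ-quantized linear layer from `vb`, falling back to an
/// unquantized linear layer or a `DummyLayer` when the expected quantized
/// tensors are absent.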
pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::GptqAwq {
        bits,
        group_size,
        checkpoint_format: _,
        is_awq,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    let is_awq = *is_awq;
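    // A plain `weight` tensor means this layer was left unquantized.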
    if vb.contains_tensor("weight") {
        return crate::linear_b(in_dim, out_dim, false, &None, vb);
    }

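    // Without the expected quantized tensors, treat the layer as a dummy.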
    if !vb.contains_tensor("qweight")
        || !vb.contains_tensor("qzeros")
        || !vb.contains_tensor("scales")
        || !is_awq && !vb.contains_tensor("g_idx")
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

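    // GPTQ packs the quantized weight along the input dimension,
    // (in_dim / pack_factor, out_dim); AWQ packs it along the output
    // dimension, (in_dim, out_dim / pack_factor).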
    let qw_shape = if !is_awq {
        (in_dim / pack_factor!(bits), out_dim)
    } else {
        (in_dim, out_dim / pack_factor!(bits))
    };

    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
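    // One row of scales and packed zero points per group of `group_size`
    // input channels.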
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
    let g_idx = if is_awq {
        None
    } else {
        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
    };
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

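    // On this backend `GptqLayer::new` rejects GPTQ/AWQ configs, so this
    // surfaces the "GPTQ is only supported on CUDA." error to the caller.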
    let config = QuantMethodConfig::GptqAwq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        qzeros: Some(qzeros),
        scales,
        g_idx,
        bias,
        workspace: None,
        is_marlin: false,
        is_awq,
    };
    Ok(Arc::new(GptqLayer::new(config)?))
}