mistralrs_quant/gptq/gptq_cpu.rs

use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};
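/// CPU stand-in for the GPTQ quantized layer. The GPTQ matmul kernels are
/// CUDA-only, so constructing this layer always returns an error.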
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::GptqAwq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}
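// How many quantized values fit into one 32-bit word, e.g. 32 / 4 = 8 for
// 4-bit quantization.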
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}
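/// Load a GPTQ- or AWQ-quantized linear layer from `vb`, falling back to an
/// unquantized linear layer or a `DummyLayer` when the expected quantized
/// tensors are not present.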
pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::GptqAwq {
        bits,
        group_size,
        checkpoint_format: _,
        is_awq,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };

    let is_awq = *is_awq;
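    // Some checkpoints leave this layer unquantized and store a plain
    // `weight` tensor; load it as a regular linear layer in that case.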
    if vb.contains_tensor("weight") {
        return crate::linear_b(in_dim, out_dim, false, &None, vb);
    }
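    // Fall back to a dummy layer when the quantized tensors are missing;
    // `g_idx` is only required for GPTQ, not AWQ.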
    if !vb.contains_tensor("qweight")
        || !vb.contains_tensor("qzeros")
        || !vb.contains_tensor("scales")
        || (!is_awq && !vb.contains_tensor("g_idx"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }
    let qw_shape = if !is_awq {
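        // GPTQ packs along the input dimension: (in_dim / pack_factor, out_dim).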
        (in_dim / pack_factor!(bits), out_dim)
    } else {
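        // AWQ packs along the output dimension: (in_dim, out_dim / pack_factor).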
        (in_dim, out_dim / pack_factor!(bits))
    };

    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
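    // Scales and zero points are stored per quantization group: one row for
    // every `group_size` input features.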
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
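    // Activation reordering indices (`g_idx`) exist only for GPTQ checkpoints.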
    let g_idx = if is_awq {
        None
    } else {
        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
    };
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };

    let config = QuantMethodConfig::GptqAwq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        qzeros: Some(qzeros),
        scales,
        g_idx,
        bias,
        workspace: None,
        is_marlin: false,
        is_awq,
    };
    Ok(Arc::new(GptqLayer::new(config)?))
}