mistralrs_quant/gptq/gptq_cpu.rs
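
//! CPU stub for the GPTQ quantized linear layer. The GPTQ matmul kernels are
//! CUDA-only, so `GptqLayer::new` always bails on this build; `gptq_linear`
//! still checks that the GPTQ tensors are present and falls back to a
//! `DummyLayer` when they are missing.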
use crate::{
    DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, ShardedVarBuilder,
};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::{atomic::AtomicUsize, Arc};
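
/// Stub GPTQ layer. `new` always returns an error on this build, so no
/// instance can exist and the remaining trait methods are unreachable
/// `todo!()`s.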
#[derive(Debug)]
pub struct GptqLayer;

impl QuantMethod for GptqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gptq { .. } => {
                candle_core::bail!("GPTQ is only supported on CUDA.")
            }
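            // All other config variants are constructed through their own
            // layer types and are never passed to `GptqLayer::new`.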
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        todo!()
    }

    fn forward(&self, _a: &Tensor) -> Result<Tensor> {
        todo!()
    }

    fn quantized_act_type(&self) -> Option<DType> {
        todo!()
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        todo!()
    }

    fn apply_isq(
        self: Arc<Self>,
        _dtype: Option<IsqType>,
        _device: Device,
        _n_quantized: &AtomicUsize,
        _imatrix_weight: Option<Vec<f32>>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        todo!()
    }
}

impl QuantizedSerde for GptqLayer {
    fn name(&self) -> &'static str {
        "gptq"
    }
}
macro_rules! pack_factor {
    ($bits:expr) => {
        32 / $bits
    };
}
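
/// Loads a GPTQ linear layer from `vb`. Expects `qweight` with shape
/// `(in_dim / pack_factor, out_dim)` (i32), `qzeros` with shape
/// `(in_dim / group_size, out_dim / pack_factor)` (i32), `g_idx` with shape
/// `(in_dim,)` (i32), `scales` with shape `(in_dim / group_size, out_dim)`
/// (f16), and an optional f16 `bias`. Falls back to a `DummyLayer` when the
/// GPTQ tensors are absent.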
pub fn gptq_linear(
    in_dim: usize,
    out_dim: usize,
    config: &QuantizedConfig,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let QuantizedConfig::Gptq {
        bits,
        group_size,
        checkpoint_format: _,
    } = config
    else {
        candle_core::bail!("Unexpected quantization config.")
    };
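
    // All four GPTQ tensors must be present; otherwise return an inert
    // `DummyLayer` in place of this projection.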
    if !(vb.contains_tensor("qweight")
        && vb.contains_tensor("qzeros")
        && vb.contains_tensor("g_idx")
        && vb.contains_tensor("scales"))
    {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }
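
    // GPTQ packs pack_factor!(bits) weights into each 32-bit word along the
    // input dimension.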
    let qweight = vb.get_with_hints_dtype(
        (in_dim / pack_factor!(bits), out_dim),
        "qweight",
        Default::default(),
        DType::I32,
    )?;
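
    // One scale and one (packed) zero point per group of `group_size` input
    // rows.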
    let scale_and_zero_size = in_dim / group_size;
    let qzeros = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim / pack_factor!(bits)),
        "qzeros",
        Default::default(),
        DType::I32,
    )?;
    let g_idx = vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?;
    let scales = vb.get_with_hints_dtype(
        (scale_and_zero_size, out_dim),
        "scales",
        Default::default(),
        DType::F16,
    )?;
    let bias = if vb.contains_tensor("bias") {
        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
    } else {
        None
    };
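
    // Constructing the layer on this CPU build surfaces the
    // "GPTQ is only supported on CUDA" error from `GptqLayer::new`.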
    let config = QuantMethodConfig::Gptq {
        bits: *bits as i32,
        use_exllama: false,
        q_weight: qweight,
        gptq_qzeros: Some(qzeros),
        gptq_scales: scales,
        g_idx: Some(g_idx),
        bias,
        workspace: None,
        is_marlin: false,
    };
    Ok(Arc::new(GptqLayer::new(config)?))
}