1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
9use candle_nn::Linear;
10
11use crate::{
12 cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
13 generate_isq, generate_isq_imatrix,
14 hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
15 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
16 AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
17 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
18};
19
/// A plain, unquantized linear layer: `y = x · wᵀ (+ b)`.
///
/// Serves as the fallback `QuantMethod` implementation and as the staging
/// point for in-situ quantization (`apply_isq`) into HQQ/AFQ/GGUF/FP8 forms.
#[derive(Debug)]
pub struct UnquantLinear {
    // Weight matrix. Indexed on dim 0 in `gather_forward` and transposed in
    // `forward`, so presumably laid out (out_features, in_features) — the
    // usual Linear convention; TODO confirm against callers.
    w: Tensor,
    // Optional bias, broadcast over the batch dims in `forward`.
    b: Option<Tensor>,
    // Activation statistics collector for imatrix-based quantization;
    // populated by `begin_track_stats`, consumed by `end_track_stats`.
    stats: Option<ImatrixLayerStats>,
}
26
impl QuantMethod for UnquantLinear {
    /// Construct from a `QuantMethodConfig`.
    ///
    /// Only the `Unquantized` variant is meaningful here; every other variant
    /// belongs to a different `QuantMethod` implementation, so hitting one is
    /// a dispatch bug upstream (`unreachable!`).
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    /// Already unquantized: "dequantizing" is just a cheap clone of the weight.
    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    /// Compute `a · wᵀ (+ b)`, picking the fastest path for the device.
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Lazily initialize the cuBLASLt wrapper for this device (no-op where
        // it does not apply).
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Broadcast the 2-D weight over the input's leading batch dims so the
        // batched matmul paths below line up; pass it through unchanged for
        // other ranks.
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        // While collecting an imatrix, fold this activation into the stats.
        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            // Broadcast the bias to the output shape: input dims with the last
            // one replaced by the weight's second-to-last dim (out features).
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        // Fused bias + matmul through cuBLASLt (alpha = 1.0).
                        // The extra `.t()` undoes the transposed layout of the
                        // cuBLASLt result; the bias is likewise pre-transposed.
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        // Fallback: seed `out` with the bias, then accumulate
                        // the product in place — presumably out ← a·wᵀ + out
                        // with default alpha/beta; TODO confirm against
                        // `matmul_with_alpha_beta` docs.
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    // Same bias-seeded accumulate as the CUDA fallback.
                    let mut out = b.contiguous()?;
                    a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                    Ok(out)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        // Accelerate works in F32: upcast inputs, accumulate,
                        // then restore the caller's dtype.
                        let original_dtype = a.dtype();
                        let mut out = b.contiguous()?.to_dtype(DType::F32)?;
                        a.to_dtype(DType::F32)?.matmul_with_alpha_beta(
                            &w.t()?.to_dtype(DType::F32)?,
                            &mut out,
                            None,
                        )?;
                        out.to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            // No bias: plain cuBLASLt batched matmul, transposed back.
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

    /// Matmul against rows of `w` selected by `indices` (dim 0).
    ///
    /// NOTE(review): the bias is NOT applied here, unlike `forward` — confirm
    /// this is intentional for the gather/MoE call sites.
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    /// No forced activation dtype: inputs are used as-is.
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    /// Return a new layer with `delta` added to the weight (e.g. LoRA merge);
    /// bias and stats are carried over unchanged.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

    /// Quantize this layer in place ("ISQ") to the requested target type.
    ///
    /// Branches: HQQ (4/8-bit), AFQ (2/3/4/6/8-bit), GGUF k-quants
    /// (optionally imatrix-weighted), FP8 (E4M3), or — for `None` — a plain
    /// device move with no quantization. `guard` serializes quantization work
    /// per device; `n_quantized` is a progress counter shared across layers.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // HQQ has no imatrix-weighted variant.
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                // `unwrap()` is safe: this arm only matches HQQ4/HQQ8.
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    // Bias stays unquantized but must match the quantized
                    // layer's device/dtype.
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // AFQ has no imatrix-weighted variant either.
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                // `unwrap()` is safe: this arm only matches AFQ variants.
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                // GGUF k-quant / legacy quant path; the macros handle the
                // guard acquisition and `n_quantized` bookkeeping internally.
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    // GGUF matmul expects an F32 bias on the target device.
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                // No quantization requested: just move to the target device.
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    /// Expose the raw weight/bias (always available for this impl).
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    /// Start collecting imatrix activation statistics on this layer's device.
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    /// Finish stats collection: compute the imatrix, clear the accumulator,
    /// and return the result. Errors if tracking was never started.
    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}
302
303impl QuantizedSerde for UnquantLinear {
318 fn isq_serde_supported(&self) -> bool {
319 true
320 }
321 fn name(&self) -> &'static str {
322 "unquant-linear"
323 }
324 fn serialize(&self) -> Result<Cow<[u8]>> {
325 self.serialize_with_bias(self.b.clone())
326 }
327 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
328 let mut buffer = Vec::new();
329
330 buffer.extend(&UQFF_VERSION.to_le_bytes());
333
334 buffer.push(QuantizedSerdeType::Unquant as u8);
336
337 buffer.push(bias.is_some() as u8);
339
340 serialize_tensor(&mut buffer, &self.w)?;
342
343 if let Some(bias) = &bias {
344 serialize_tensor(&mut buffer, bias)?;
346 }
347
348 Ok(Cow::from(buffer))
349 }
350
351 fn deserialize(
352 data: Cow<[u8]>,
353 device: &Device,
354 _comm: &Arc<crate::Comm>,
355 guard: QuantizeOntoGuard,
356 ) -> Result<Arc<dyn QuantMethod>>
357 where
358 Self: Sized,
359 {
360 let mut buffer = Cursor::new(data);
361
362 let version = buffer.read_u32::<LittleEndian>()?;
363 if let Err(e) = version_is_compatible(version) {
364 return Err(candle_core::Error::wrap(e));
365 }
366
367 let isq_type = buffer.read_u8()? as usize;
368 if isq_type != QuantizedSerdeType::Unquant as usize {
369 candle_core::bail!(
370 "ISQ type ({isq_type}) doesn't match expected type {}",
371 QuantizedSerdeType::Unquant as usize
372 );
373 }
374
375 let has_bias = buffer.read_u8()? != 0;
376
377 let _acquired_load_guard = guard.acquire(device);
378 let w = deserialize_tensor(&mut buffer, device)?;
379
380 let b = if has_bias {
381 Some(deserialize_tensor(&mut buffer, device)?)
382 } else {
383 None
384 };
385
386 Ok(Arc::new(Self { w, b, stats: None }))
387 }
388 fn deserialize_ext_bias(
389 data: Cow<[u8]>,
390 device: &Device,
391 guard: QuantizeOntoGuard,
392 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
393 where
394 Self: Sized,
395 {
396 let mut buffer = Cursor::new(data);
397
398 let version = buffer.read_u32::<LittleEndian>()?;
399 if let Err(e) = version_is_compatible(version) {
400 return Err(candle_core::Error::wrap(e));
401 }
402
403 let isq_type = buffer.read_u8()? as usize;
404 if isq_type != QuantizedSerdeType::Unquant as usize {
405 candle_core::bail!(
406 "ISQ type ({isq_type}) doesn't match expected type {}",
407 QuantizedSerdeType::Unquant as usize
408 );
409 }
410
411 let has_bias = buffer.read_u8()? != 0;
412
413 let _acquired_load_guard = guard.acquire(device);
414 let w = deserialize_tensor(&mut buffer, device)?;
415
416 let b = if has_bias {
417 Some(deserialize_tensor(&mut buffer, device)?)
418 } else {
419 None
420 };
421
422 Ok((
423 Arc::new(Self {
424 w,
425 b: None,
426 stats: None,
427 }),
428 b,
429 ))
430 }
431}