1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
9use candle_nn::Linear;
10
11use crate::{
12 cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_HANDLE},
13 generate_isq, generate_isq_imatrix,
14 hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
15 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
16 AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
17 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
18};
19
/// A plain, unquantized linear layer (`w`, optional `b`) implementing [`QuantMethod`].
///
/// Serves as the starting point for in-situ quantization (ISQ): `apply_isq`
/// converts it into an HQQ/AFQ/GGUF/FP8 layer, and it can track activation
/// statistics (`stats`) used to build an imatrix for weighted quantization.
#[derive(Debug)]
pub struct UnquantLinear {
    /// Weight tensor; `forward` multiplies by its transpose (`a @ w^T`).
    w: Tensor,
    /// Optional bias, broadcast over the output shape in `forward`.
    b: Option<Tensor>,
    /// Activation statistics, populated between `begin_track_stats` and
    /// `end_track_stats`; `None` when tracking is off.
    stats: Option<ImatrixLayerStats>,
}
26
impl QuantMethod for UnquantLinear {
    /// Build from [`QuantMethodConfig::Unquantized`] holding a `candle_nn::Linear`.
    ///
    /// Every other config variant is a programming error for this type, hence
    /// `unreachable!()` — the caller selects the `QuantMethod` impl before
    /// constructing its config.
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    /// The weight is already unquantized; "dequantizing" is just a clone.
    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    /// Linear forward pass: `a @ w^T (+ b)`.
    ///
    /// Fast paths:
    /// - CUDA with an initialized cuBLASLt handle: fused batched matmul, with
    ///   the (transposed) bias folded in as the accumulator when present.
    /// - Otherwise: `matmul_with_alpha_beta` accumulating into a broadcast
    ///   copy of the bias; on CPU with the `accelerate` feature the operands
    ///   are round-tripped through f32.
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // One-time lazy initialization of the cuBLASLt wrapper for this
        // device (a no-op on non-CUDA devices / subsequent calls).
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Broadcast the weight across the leading batch dims of `a` when `a`
        // is rank 3 or 4; any other rank uses the weight as-is.
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        // Feed activations into the imatrix statistics when tracking is on.
        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            // Target output shape: `a`'s dims with the last replaced by the
            // output-feature count (w's second-to-last dim, since we multiply
            // by w^T). The bias is broadcast to that full shape so it can act
            // as the matmul accumulator below.
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), *CUBLASLT_HANDLE.lock().unwrap())
                    {
                        // Fused path: cuBLASLt produces a transposed result,
                        // so the bias is passed transposed (made contiguous)
                        // and the output is transposed back afterwards.
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        // Fallback: beta-accumulate `a @ w^T` into a copy of
                        // the broadcast bias.
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let mut out = b.contiguous()?;
                    a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                    Ok(out)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        // Accelerate path works in f32; convert operands in,
                        // convert the result back to the original dtype.
                        let original_dtype = a.dtype();
                        let mut out = b.contiguous()?.to_dtype(DType::F32)?;
                        a.to_dtype(DType::F32)?.matmul_with_alpha_beta(
                            &w.t()?.to_dtype(DType::F32)?,
                            &mut out,
                            None,
                        )?;
                        out.to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let mut out = b.contiguous()?;
                        a.matmul_with_alpha_beta(&w.t()?, &mut out, None)?;
                        Ok(out)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) =
            (a.device(), *CUBLASLT_HANDLE.lock().unwrap())
        {
            // No bias: plain fused batched matmul on CUDA.
            cublaslt
                .batch_matmul(a, &w, None, None, None, None, None)?
                .t()
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

    /// Multiply `a` by a row-subset of the weight selected by `indices`
    /// (index_select along dim 0), e.g. expert selection.
    ///
    /// NOTE(review): unlike `forward`, neither the bias nor the imatrix stats
    /// are applied here — confirm callers expect the bias-free product.
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = self.w.index_select(indices, 0)?;

        a.broadcast_matmul(&w.t()?)
    }

    /// No activation cast is required for unquantized weights.
    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    /// Return a new layer with `delta` added to the weight; bias and stats
    /// are carried over unchanged (e.g. for merging weight deltas).
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

    /// In-situ quantize this layer to `dtype`, placing the result on `device`.
    ///
    /// Dispatch:
    /// - `HQQ4`/`HQQ8`      -> [`HqqLayer`] (imatrix unsupported: bails)
    /// - `AFQ2..AFQ8`       -> [`AfqLayer`] (imatrix unsupported: bails)
    /// - GGUF `Q*` types    -> [`GgufMatMul`], imatrix-weighted when provided
    /// - `F8E4M3`           -> [`FP8Linear`] (imatrix unsupported: bails)
    /// - `None`             -> unquantized copy moved to `device`
    ///
    /// `n_quantized` counts completed layers (for progress reporting) and
    /// `guard` serializes quantization work across layers.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    // Bias must live on the target device in the quantized
                    // layer's compute dtype.
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                // NOTE(review): unlike the other branches there is no explicit
                // `guard.acquire()` here — presumably the generate_isq* macros
                // acquire it internally; confirm against their definitions.
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    // GGUF matmul takes an f32 bias on the target device.
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire();
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                // No quantization requested: copy to `device` unchanged.
                // Stats tracking is not carried over to the new layer.
                let _acquired_quantize_guard = guard.acquire();
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    /// Expose the raw weight/bias (both cloned).
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    /// Start collecting activation statistics; subsequent `forward` calls
    /// feed into them.
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    /// Finish stat collection: compute the imatrix, clear the accumulator,
    /// and return the imatrix. Errors if tracking was never begun.
    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}
304
305impl QuantizedSerde for UnquantLinear {
320 fn isq_serde_supported(&self) -> bool {
321 true
322 }
323 fn name(&self) -> &'static str {
324 "unquant-linear"
325 }
326 fn serialize(&self) -> Result<Cow<[u8]>> {
327 self.serialize_with_bias(self.b.clone())
328 }
329 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
330 let mut buffer = Vec::new();
331
332 buffer.extend(&UQFF_VERSION.to_le_bytes());
335
336 buffer.push(QuantizedSerdeType::Unquant as u8);
338
339 buffer.push(bias.is_some() as u8);
341
342 serialize_tensor(&mut buffer, &self.w)?;
344
345 if let Some(bias) = &bias {
346 serialize_tensor(&mut buffer, bias)?;
348 }
349
350 Ok(Cow::from(buffer))
351 }
352
353 fn deserialize(
354 data: Cow<[u8]>,
355 device: &Device,
356 _comm: &Arc<crate::Comm>,
357 guard: QuantizeOntoGuard,
358 ) -> Result<Arc<dyn QuantMethod>>
359 where
360 Self: Sized,
361 {
362 let mut buffer = Cursor::new(data);
363
364 let version = buffer.read_u32::<LittleEndian>()?;
365 if let Err(e) = version_is_compatible(version) {
366 return Err(candle_core::Error::wrap(e));
367 }
368
369 let isq_type = buffer.read_u8()? as usize;
370 if isq_type != QuantizedSerdeType::Unquant as usize {
371 candle_core::bail!(
372 "ISQ type ({isq_type}) doesn't match expected type {}",
373 QuantizedSerdeType::Unquant as usize
374 );
375 }
376
377 let has_bias = buffer.read_u8()? != 0;
378
379 let _acquired_load_guard = guard.acquire();
380 let w = deserialize_tensor(&mut buffer, device)?;
381
382 let b = if has_bias {
383 Some(deserialize_tensor(&mut buffer, device)?)
384 } else {
385 None
386 };
387
388 Ok(Arc::new(Self { w, b, stats: None }))
389 }
390 fn deserialize_ext_bias(
391 data: Cow<[u8]>,
392 device: &Device,
393 guard: QuantizeOntoGuard,
394 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
395 where
396 Self: Sized,
397 {
398 let mut buffer = Cursor::new(data);
399
400 let version = buffer.read_u32::<LittleEndian>()?;
401 if let Err(e) = version_is_compatible(version) {
402 return Err(candle_core::Error::wrap(e));
403 }
404
405 let isq_type = buffer.read_u8()? as usize;
406 if isq_type != QuantizedSerdeType::Unquant as usize {
407 candle_core::bail!(
408 "ISQ type ({isq_type}) doesn't match expected type {}",
409 QuantizedSerdeType::Unquant as usize
410 );
411 }
412
413 let has_bias = buffer.read_u8()? != 0;
414
415 let _acquired_load_guard = guard.acquire();
416 let w = deserialize_tensor(&mut buffer, device)?;
417
418 let b = if has_bias {
419 Some(deserialize_tensor(&mut buffer, device)?)
420 } else {
421 None
422 };
423
424 Ok((
425 Arc::new(Self {
426 w,
427 b: None,
428 stats: None,
429 }),
430 b,
431 ))
432 }
433}