use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

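/// An unquantized linear layer: the weight `w` is kept in its original dtype
/// with shape `(out_features, in_features)`, `b` is an optional bias, and
/// `stats` optionally accumulates activation statistics for imatrix-guided
/// requantization.
///
/// Minimal usage sketch (hypothetical shapes and bindings, not from this
/// crate's tests):
///
/// ```ignore
/// let w = Tensor::zeros((out_dim, in_dim), DType::F32, &Device::Cpu)?;
/// let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, None)))?;
/// let y = layer.forward(&x)?; // y = x @ w^T (+ b, when present)
/// ```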
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

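    // Forward dispatch, sketching the branches below: on CUDA a dedicated
    // GEMV path is tried first, then cuBLASLt batched matmul (which appears
    // to fuse the bias add by seeding the output with the broadcast,
    // transposed bias and `beta = 1.0`); CPU builds with the `accelerate`
    // feature compute in f32 and cast back to the input dtype; everything
    // else falls back to `matmul` with a broadcast bias add.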
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        maybe_init_cublas_lt_wrapper(a.device().clone());

        #[cfg(feature = "cuda")]
        if crate::gemv::should_use_gemv(a, &self.w) {
            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
        }

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) =
            (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
        {
            if a.rank() >= 3 && w.rank() >= 3 {
                cublaslt
                    .batch_matmul(a, &w, None, None, None, None, None)?
                    .t()
            } else {
                MatMul.matmul(a, &w.t()?)
            }
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

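    // MoE expert gather: `w` is a stacked `(num_experts, out_features,
    // in_features)` weight and `indices` holds the expert ids selected per
    // token. Each token's hidden vector is repeated once per selected
    // expert, paired with the `index_select`-ed expert weight, and computed
    // as a batched `(1, hidden) x (hidden, out)` matmul. The bias, if any,
    // is not applied on this path.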
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

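    // Merge an additive weight delta (e.g. a materialized LoRA update) into
    // a fresh layer, carrying the bias and any tracked stats over unchanged.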
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

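    // In-situ quantization: requantize this layer's weight on `device` into
    // the requested format. Each arm maps a family of `IsqType`s to its
    // backend (HQQ, AFQ, GGUF k-quants via the `generate_isq*` macros, or
    // FP8), bumping `n_quantized` for progress reporting and acquiring
    // `guard` so quantization onto the device is serialized. Only the GGUF
    // path accepts an imatrix; HQQ, AFQ, and FP8 reject it explicitly.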
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

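    // Imatrix stats lifecycle: `begin_track_stats` attaches an
    // `ImatrixLayerStats` accumulator, `forward` then feeds each activation
    // through `stats.process`, and `end_track_stats` computes the importance
    // matrix and clears the accumulator, ready to be passed (after
    // conversion) to `apply_isq` as `imatrix_weight`.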
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
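    // UQFF wire format produced below (tensor encoding is delegated to
    // `serialize_tensor`, so only the header is spelled out here):
    //
    //   [u32 LE]  UQFF_VERSION
    //   [u8]      QuantizedSerdeType::Unquant tag
    //   [u8]      1 if a bias follows the weight, else 0
    //   [...]     weight tensor (via `serialize_tensor`)
    //   [...]     bias tensor, only if the flag byte was 1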
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Unquant as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

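    // Inverse of `serialize_with_bias`: validate the version, check the type
    // tag, read the bias flag, then decode the tensors directly onto
    // `device`. Round-trip sketch (hypothetical bindings):
    //
    // ```ignore
    // let bytes = layer.serialize()?;
    // let restored = UnquantLinear::deserialize(bytes, &Device::Cpu, &comm, guard)?;
    // ```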
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
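    // Same wire format as `deserialize`, but the decoded bias is returned to
    // the caller instead of being stored on the layer, for callers that
    // manage the bias separately.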
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}