use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

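/// An unquantized linear layer: the weight (and optional bias) are kept as
/// full-precision tensors. This serves as the fallback `QuantMethod` and as
/// the starting point for in-situ quantization (ISQ) via `apply_isq`.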
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

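    // Computes x @ W^T (+ b), broadcasting the weight over any leading batch
    // dimensions and choosing a fused device-specific path where one exists.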
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        maybe_init_cublas_lt_wrapper(a.device().clone());

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

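        // With a bias, broadcast it to the output shape and dispatch on the
        // device to pick the fastest matmul + bias-add combination.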
        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

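            // On CUDA, prefer cuBLASLt: the transposed broadcast bias is
            // passed as the accumulator with beta = 1.0, fusing the bias add
            // into the GEMM. Otherwise fall back to matmul + broadcast_add.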
            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
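                    // With the `accelerate` feature, do the matmul in f32
                    // (the BLAS fast path) and cast back to the input dtype.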
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            if a.rank() >= 3 && w.rank() >= 3 {
                cublaslt
                    .batch_matmul(a, &w, None, None, None, None, None)?
                    .t()
            } else {
                MatMul.matmul(a, &w.t()?)
            }
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

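    // Gather-matmul for expert routing: `w` is (num_experts, out_features,
    // in_features) and `indices` selects the experts each token is routed to.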
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            &[b_size, seq_len, 1, 1, hidden_dim] => {
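                // (b, s, 1, 1, h): every (batch, seq) position is routed to
                // `num_experts_per_tok` experts; flatten the routing indices.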
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

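                // Gather each token's expert weight matrices, then repeat each
                // token's activation once per selected expert.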
                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

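                // Batched (1, h) x (h, out) matmul, one per token-expert pair.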
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            &[num_tokens, 1, hidden_dim] => {
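                // (tokens, 1, h): flattened per-token path; same
                // gather-and-expand scheme without the (batch, seq) split.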
                let (_, num_experts_per_tok) = indices.dims2()?;

                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

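    // In-situ quantization: consumes this full-precision layer and returns the
    // requested quantized replacement (HQQ, AFQ, GGUF k-quants, or FP8).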
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
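                // GGUF quant types: use the imatrix-weighted variant when
                // importance data was collected, plain quantization otherwise.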
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

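    // Imatrix (importance matrix) collection: `begin_track_stats` starts
    // accumulating activation statistics, `end_track_stats` finalizes them.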
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
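    // Serialized UQFF layout, as written below:
    //   UQFF version   (u32, little endian)
    //   ISQ type       (u8; `QuantizedSerdeType::Unquant`)
    //   has-bias flag  (u8 boolean)
    //   weight tensor  (via `serialize_tensor`)
    //   bias tensor    (via `serialize_tensor`, only if has-bias is set)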
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Unquant as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}