use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType, MatMul,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

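/// Linear layer that keeps its weight and bias unquantized, with optional
/// imatrix activation statistics for calibration-guided ISQ.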
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

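    // Matmul dispatch: a GEMV fast path when the `cuda` feature is enabled
    // and the shapes qualify, cuBLASLt batched matmul (with fused bias) on
    // CUDA, an f32 Accelerate path on CPU, and a plain matmul +
    // broadcast_add fallback elsewhere.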
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        maybe_init_cublas_lt_wrapper(a.device().clone());

        #[cfg(feature = "cuda")]
        if crate::gemv::should_use_gemv(a, &self.w) {
            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
        }

        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get())
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else if let (Device::Cuda(_), Some(cublaslt)) = (a.device(), CUBLASLT_CONTROLLER.get()) {
            if a.rank() >= 3 && w.rank() >= 3 {
                cublaslt
                    .batch_matmul(a, &w, None, None, None, None, None)?
                    .t()
            } else {
                MatMul.matmul(a, &w.t()?)
            }
        } else {
            MatMul.matmul(a, &w.t()?)
        }
    }

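    // MoE gather-matmul: `indices` routes each token to `num_experts_per_tok`
    // experts; the matching expert weight matrices are gathered from `w`
    // (shape `(num_experts, out_features, in_features)`) and applied per
    // (token, expert) pair.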
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                // Flatten routing indices so each (token, expert) pair maps to
                // one row of `w`.
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                // (tokens * experts_per_tok, out_features, in_features)
                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                // Repeat each token's activations once per selected expert.
                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                // Batched vector-matrix product: (1, hidden) x (hidden, out)
                // for every (token, expert) pair.
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

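    // Returns a new layer with `delta` added to the weight (e.g. when merging
    // a LoRA-style adapter); bias and stats are carried over unchanged.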
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

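    // In-situ quantization: converts this unquantized layer into the requested
    // representation (HQQ, AFQ, GGUF k-quants, or FP8), moving the weight to
    // `device` first. `None` just relocates the layer, still unquantized.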
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

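    // Imatrix tracking: `begin_track_stats` attaches an accumulator that
    // `forward` feeds with activations; `end_track_stats` returns the computed
    // importance matrix and clears the accumulator.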
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
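    // UQFF wire format for an unquantized layer:
    //   [u32 LE] UQFF version
    //   [u8]     QuantizedSerdeType::Unquant tag
    //   [u8]     has-bias flag
    //   [tensor] weight, then bias if the flag is set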
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Unquant as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

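    // Inverse of `serialize_with_bias`: validates the version and type tag,
    // then reads the weight (and bias, if present) directly onto `device`.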
    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
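    // Like `deserialize`, but hands the bias back to the caller instead of
    // storing it on the layer (the returned layer has `b: None`).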
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}