use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice},
    from_storage_no_op, CudaStorage, Storage,
};

#[cfg(feature = "metal")]
use candle_core::{from_storage_no_op, Storage};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::get_cuda_device;

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
mod bitpack_ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

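// HQQ defaults: group size and optimizer step count used when HQQ is selected through ISQ,
// plus the fallback step count used by the quantization optimizer.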
pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

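// Dequantizes `$this.w_q` against `$this.scales`/`$this.zeros` by invoking the CUDA FFI
// kernel selected by `$postfix`, then wraps the newly allocated device buffer in a Tensor
// without copying. `pack` is the number of original rows packed into each stored row of `w_q`.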
#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let (wq, _) = $this.w_q.storage_and_layout();
            let wq = match &*wq {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("wq must be a cuda tensor"),
            };
            let (w_slice, _w_guard) = crate::utils::slice_ptr(wq.as_cuda_slice::<$wq_t>()?, $this.w_q.layout().start_offset());

            let (scale, _) = $this.scales.storage_and_layout();
            let scale = match &*scale {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("scale must be a cuda tensor"),
            };
            let (scale_slice, _scale_guard) = crate::utils::slice_ptr(scale.as_cuda_slice::<$scale_t>()?, $this.scales.layout().start_offset());

            let (zero, _) = $this.zeros.storage_and_layout();
            let zero = match &*zero {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("zero must be a cuda tensor"),
            };
            let (zero_slice, _zero_guard) = crate::utils::slice_ptr(zero.as_cuda_slice::<$scale_t>()?, $this.zeros.layout().start_offset());

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count())? };
            let (out_ptr, out_guard) = out.device_ptr(out.stream());
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice as *const $wq_t,
                    scale_slice as *const $scale_t,
                    zero_slice as *const $scale_t,
                    out_ptr as *mut $scale_t,
                    h as i32,
                    w as i32,
                );
            }
            drop(out_guard);

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}

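/// Axis along which HQQ grouping/quantization is applied.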
#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

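/// Bit width used for the quantized weights.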
#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
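    /// Returns a closure that packs a tensor of quantized values (one value per element)
    /// into the dense storage layout expected by the dequantization kernels.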
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    let dev = get_cuda_device(&wq)?;
                    let wq = wq.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_shape = wq.shape().clone();
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_8bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            output_shape.elem_count(),
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                #[cfg(feature = "metal")]
                if device.is_metal() {
                    use candle_core::MetalStorage;

                    let dev = device.as_metal_device()?;
                    let command_buffer = dev.command_buffer()?;
                    command_buffer.set_label("hqq_pack_8bit");

                    let (wq_storage, _wq_layout) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Metal(s) => s,
                        _ => candle_core::bail!("Expected Metal storage"),
                    };

                    let output_shape = wq.shape().clone();
                    let output = dev.new_buffer(
                        output_shape.elem_count(),
                        DType::U8,
                        "hqq_pack_8bit_output",
                    )?;

                    crate::metal_kernels::call_hqq_pack_8bit(
                        dev.device(),
                        &command_buffer,
                        &crate::metal_kernels::Kernels::new(),
                        wq_storage.buffer(),
                        &output,
                        output_shape.elem_count(),
                    )
                    .map_err(candle_core::Error::wrap)?;

                    let storage = MetalStorage::new(
                        output,
                        dev.clone(),
                        output_shape.elem_count(),
                        DType::U8,
                    );
                    let storage = Storage::Metal(storage);

                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                wq.to_dtype(DType::U8)
            },
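            // 4-bit: two values per byte. The first half of the rows fills the high nibble,
            // the second half the low nibble.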
            Self::Four => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 2;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_4bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                #[cfg(feature = "metal")]
                if device.is_metal() {
                    use candle_core::MetalStorage;

                    let dev = device.as_metal_device()?;
                    let command_buffer = dev.command_buffer()?;
                    command_buffer.set_label("hqq_pack_4bit");

                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _wq_layout) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Metal(s) => s,
                        _ => candle_core::bail!("Expected Metal storage"),
                    };

                    let output_height = wq.dims()[0] / 2;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = dev.new_buffer(
                        output_shape.elem_count(),
                        DType::U8,
                        "hqq_pack_4bit_output",
                    )?;

                    crate::metal_kernels::call_hqq_pack_4bit(
                        dev.device(),
                        &command_buffer,
                        &crate::metal_kernels::Kernels::new(),
                        wq_storage.buffer(),
                        &output,
                        wq.dims()[0],
                        wq.dims()[1],
                    )
                    .map_err(candle_core::Error::wrap)?;

                    let storage = MetalStorage::new(
                        output,
                        dev.clone(),
                        output_shape.elem_count(),
                        DType::U8,
                    );
                    let storage = Storage::Metal(storage);

                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
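            // 2-bit: four values per byte, packed from the four quarters of the rows
            // (highest bits first).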
            Self::Two => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 4;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_2bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    Ok(from_storage_no_op(storage, output_shape, false))
                } else {
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let step = (wq.dims()[0] as f64 / 4.) as usize;

                    let a = wq.narrow(0, 0, step)?;
                    let b = wq.narrow(0, step, step)?;
                    let c = wq.narrow(0, step * 2, step)?;
                    let d = wq.narrow(0, step * 3, step)?;

                    a.leftshift(6)?
                        .bitwise_or(&b.leftshift(4)?)?
                        .bitwise_or(&c.leftshift(2)?)?
                        .bitwise_or(&d)
                }
                #[cfg(not(feature = "cuda"))]
                {
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let step = (wq.dims()[0] as f64 / 4.) as usize;

                    let a = wq.narrow(0, 0, step)?;
                    let b = wq.narrow(0, step, step)?;
                    let c = wq.narrow(0, step * 2, step)?;
                    let d = wq.narrow(0, step * 3, step)?;

                    a.leftshift(6)?
                        .bitwise_or(&b.leftshift(4)?)?
                        .bitwise_or(&c.leftshift(2)?)?
                        .bitwise_or(&d)
                }
            },
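            // 3-bit: rows are zero-padded to a multiple of 10, then ten values are packed
            // into each 32-bit word (3 bits per value, top 2 bits unused).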
            Self::Three => |wq_in: Tensor| -> Result<Tensor> {
                let device = wq_in.device();

                let padded_height = (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize;
                let wq = Tensor::zeros((padded_height, wq_in.dims()[1]), DType::U32, device)?;
                let wq = wq.slice_assign(
                    &[0..wq_in.dims()[0], 0..wq.dims()[1]],
                    &wq_in.to_dtype(DType::U32)?,
                )?;

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    let dev = get_cuda_device(&wq)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = padded_height / 10;
                    let output_shape = Shape::from_dims(&[output_height, wq_in.dims()[1]]);
                    let output = unsafe { dev.alloc::<i32>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u32>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_3bit_kernel(
                            input_ptr as *const u32,
                            output_ptr as *mut i32,
                            padded_height,
                            wq_in.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                let wq = if wq.device().is_metal() {
                    let cpu_wq = wq.to_device(&Device::Cpu)?;
                    cpu_wq.to_dtype(DType::I32)?.to_device(wq.device())?
                } else {
                    wq.to_dtype(DType::I32)?
                };
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            },
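            // 1-bit: eight values per byte, packed from the eight row-slices
            // (most significant bit first).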
            Self::One => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 8;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_1bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    Ok(from_storage_no_op(storage, output_shape, false))
                } else {
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let step = (wq.dims()[0] as f64 / 8.) as usize;

                    let a = wq.narrow(0, 0, step)?;
                    let b = wq.narrow(0, step, step)?;
                    let c = wq.narrow(0, step * 2, step)?;
                    let d = wq.narrow(0, step * 3, step)?;
                    let e = wq.narrow(0, step * 4, step)?;
                    let f = wq.narrow(0, step * 5, step)?;
                    let g = wq.narrow(0, step * 6, step)?;
                    let h = wq.narrow(0, step * 7, step)?;

                    a.leftshift(7)?
                        .bitwise_or(&b.leftshift(6)?)?
                        .bitwise_or(&c.leftshift(5)?)?
                        .bitwise_or(&d.leftshift(4)?)?
                        .bitwise_or(&e.leftshift(3)?)?
                        .bitwise_or(&f.leftshift(2)?)?
                        .bitwise_or(&g.leftshift(1)?)?
                        .bitwise_or(&h)
                }
                #[cfg(not(feature = "cuda"))]
                {
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let step = (wq.dims()[0] as f64 / 8.) as usize;

                    let a = wq.narrow(0, 0, step)?;
                    let b = wq.narrow(0, step, step)?;
                    let c = wq.narrow(0, step * 2, step)?;
                    let d = wq.narrow(0, step * 3, step)?;
                    let e = wq.narrow(0, step * 4, step)?;
                    let f = wq.narrow(0, step * 5, step)?;
                    let g = wq.narrow(0, step * 6, step)?;
                    let h = wq.narrow(0, step * 7, step)?;

                    a.leftshift(7)?
                        .bitwise_or(&b.leftshift(6)?)?
                        .bitwise_or(&c.leftshift(5)?)?
                        .bitwise_or(&d.leftshift(4)?)?
                        .bitwise_or(&e.leftshift(3)?)?
                        .bitwise_or(&f.leftshift(2)?)?
                        .bitwise_or(&g.leftshift(1)?)?
                        .bitwise_or(&h)
                }
            },
        }
    }
}

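/// Configuration for HQQ quantization of a weight tensor.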
#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,
    pub channel_wise: bool,
}

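/// A linear layer whose weight is HQQ-quantized: packed quantized values plus per-group
/// scales and zero points, an optional bias, and the original weight shape.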
#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
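    /// Dequantize `w_q` against `scales` and `zeros` into a tensor of shape `w_shape`
    /// (non-CUDA build: custom ops from `hqq_op`).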
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }

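    /// Dequantize `w_q` against `scales` and `zeros` into a tensor of shape `w_shape`
    /// (CUDA build: per-bit-width FFI kernels selected through `dequant_for_dtype!`).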
    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

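    /// Dequantize the weight and run the matmul (plus optional bias) through an
    /// unquantized linear layer.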
    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}

impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire(&device);
        if imatrix_weight.is_some() {
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            _ => candle_core::bail!("Expected a HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

impl QuantizedSerde for HqqLayer {
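    // UQFF serialization layout for HQQ (see `serialize_with_bias`):
    //   - UQFF version (u32, little endian)
    //   - ISQ type (u8, `QuantizedSerdeType::Hqq`)
    //   - Whether bias data is included (u8 boolean)
    //   - Quantized weight, scale, and zero tensors (via `serialize_tensor`)
    //   - Weight shape: dimension count (u32) followed by each dimension (u32)
    //   - Config: bits (u8), group size (u32), axis (u8), optimization steps (u32, 0 = none),
    //     round_zeros (u8 boolean), channel_wise (u8 boolean)
    //   - Optional bias tensor (via `serialize_tensor`)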
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Hqq as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        let shape_len = w_shape.len();
        if shape_len > u32::MAX as usize {
            candle_core::bail!(
                "Weight tensor has too many dimensions for UQFF format: {} exceeds u32::MAX",
                shape_len
            );
        }
        buffer.extend((shape_len as u32).to_le_bytes());
        for dim in w_shape {
            if *dim > u32::MAX as usize {
                candle_core::bail!(
                    "Weight tensor dimension too large for UQFF format: {} exceeds u32::MAX",
                    dim
                );
            }
            buffer.extend((*dim as u32).to_le_bytes());
        }

        buffer.push(self.cfg.bits as u8);
        let group_size = <NonZeroUsize as Into<usize>>::into(self.cfg.group_size);
        if group_size > u32::MAX as usize {
            candle_core::bail!(
                "HQQ group size too large for UQFF format: {} exceeds u32::MAX",
                group_size
            );
        }
        buffer.extend(&(group_size as u32).to_le_bytes());
        buffer.push(self.cfg.axis as u8);
        let opt_steps = self.cfg.optimization_steps.unwrap_or(0);
        if opt_steps > u32::MAX as usize {
            candle_core::bail!(
                "HQQ optimization steps too large for UQFF format: {} exceeds u32::MAX",
                opt_steps
            );
        }
        buffer.extend(&(opt_steps as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}

impl HqqLayer {
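    /// Read just enough of a UQFF-serialized HQQ layer to recover its ISQ type
    /// (HQQ8 or HQQ4) without materializing the tensors.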
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("cannot convert hqq bits to isq type")
            }
        }
    }
}