use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice, WrapErr},
    from_storage_no_op, CudaStorage, Storage,
};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::{get_cuda_device, get_cuda_slice};

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

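// Dequantizes HQQ-packed weights on CUDA. Fetches raw device pointers for the
// packed weights, scales, and zeros, allocates an uninitialized output buffer
// of `pack * h` rows, launches the matching FFI kernel, and wraps the buffer
// into a tensor without copying via `from_storage_no_op`.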
#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let w_slice = get_cuda_slice::<$wq_t>(&$this.w_q)?;
            let scale_slice = get_cuda_slice::<$scale_t>(&$this.scales)?;
            let zero_slice = get_cuda_slice::<$scale_t>(&$this.zeros)?;

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count()).w()? };
            let out_ptr = *out.device_ptr() as *mut $scale_t;
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice,
                    scale_slice,
                    zero_slice,
                    out_ptr,
                    h as i32,
                    w as i32,
                );
            }

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}

#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
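    /// Returns the bit-packing function for this bit width. Each packer splits
    /// the quantized rows into equal chunks along dim 0 and ORs them together
    /// after shifting, so e.g. 4-bit packing stores the first half of the rows
    /// in the high nibble and the second half in the low nibble of each byte.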
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| wq.to_dtype(DType::U8),
            Self::Four => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
            Self::Two => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 4.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;

                a.leftshift(6)?
                    .bitwise_or(&b.leftshift(4)?)?
                    .bitwise_or(&c.leftshift(2)?)?
                    .bitwise_or(&d)
            },
            Self::Three => |wq_in: Tensor| {
                // Zero-pad the row count up to a multiple of 10 so that ten
                // 3-bit values can be packed into each 32-bit word.
                let wq = Tensor::zeros(
                    (
                        (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize,
                        wq_in.dims()[1],
                    ),
                    DType::U32,
                    wq_in.device(),
                )?;
                let wq = wq
                    .slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::U32)?)?
                    .to_dtype(DType::I32)?;
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            },
            Self::One => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 8.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;

                a.leftshift(7)?
                    .bitwise_or(&b.leftshift(6)?)?
                    .bitwise_or(&c.leftshift(5)?)?
                    .bitwise_or(&d.leftshift(4)?)?
                    .bitwise_or(&e.leftshift(3)?)?
                    .bitwise_or(&f.leftshift(2)?)?
                    .bitwise_or(&g.leftshift(1)?)?
                    .bitwise_or(&h)
            },
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,   // default false
    pub channel_wise: bool,  // default true
}

#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
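    // Dequantize `self.w_q` into a tensor of shape `self.w_shape` with the
    // dtype of the scales/zeros. CPU path: dispatches to the `hqq_op` custom
    // ops per bit width.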
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }

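    // CUDA path: dispatches to the FFI dequantize kernels via
    // `dequant_for_dtype!`, selecting on (bits, scale dtype).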
    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            // 8 bits: one value per u8.
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            // 4 bits: two values per u8.
            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            // 3 bits: ten values per i32. Packing zero-pads the rows, so trim
            // the dequantized output back down afterwards.
            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            // 2 bits: four values per u8.
            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            // 1 bit: eight values per u8.
            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

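    // Matmul via full dequantization: materializes the weight matrix and
    // forwards through an unquantized linear layer.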
    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}

impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

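    // Requantize to a different HQQ bit width (HQQ8/HQQ4) by dequantizing the
    // stored weights and re-running quantization with the default ISQ settings.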
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire(&device);
        if imatrix_weight.is_some() {
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            _ => candle_core::bail!("Expected a HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

impl QuantizedSerde for HqqLayer {
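    // Serialization layout (kept in sync between `serialize_with_bias`,
    // `deserialize`, and `deserialize_ext_bias` below):
    //
    // - UQFF version, u32, little endian
    // - ISQ type (`QuantizedSerdeType::Hqq`), u8
    // - Has-bias flag, u8
    // - `w_q`, `scales`, and `zeros` tensors
    // - Weight shape: rank as u32, then each dim as u32, little endian
    // - Config: bits u8, group size u32, axis u8, optimization steps u32
    //   (0 encodes `None`), round zeros u8, channel wise u8
    // - Bias tensor, if the flag is set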
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is u32, little endian.
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for HQQ.
        buffer.push(QuantizedSerdeType::Hqq as u8);

        // Whether a bias tensor follows the config.
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        buffer.extend((w_shape.len() as u32).to_le_bytes());
        for dim in w_shape {
            buffer.extend((*dim as u32).to_le_bytes());
        }

        // Config.
        buffer.push(self.cfg.bits as u8);
        buffer.extend(
            &(<NonZeroUsize as Into<usize>>::into(self.cfg.group_size) as u32).to_le_bytes(),
        );
        buffer.push(self.cfg.axis as u8);
        // 0 optimization steps encodes `None`.
        buffer.extend(&(self.cfg.optimization_steps.unwrap_or(0) as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}

impl HqqLayer {
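    /// Read just enough of a UQFF buffer to recover the ISQ type of a
    /// serialized HQQ layer, without materializing any tensor data.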
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        // Skip over the serialized w_q, scales, and zeros tensors.
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("Cannot convert HQQ bits to an ISQ type")
            }
        }
    }
}