use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice, WrapErr},
    from_storage_no_op, CudaStorage, Storage,
};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::{get_cuda_device, get_cuda_slice};

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

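// Defaults used when HQQ is chosen through the ISQ path (see `apply_isq` below):
// 64-element groups along axis 0 with 10 optimization steps.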
pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

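// Dequantizes on CUDA for a given packed weight type / scale dtype: fetch the raw device
// slices for `w_q`, `scales`, and `zeros`, allocate an output with `pack * h` rows, call the
// per-bit-width FFI kernel, and wrap the result into a tensor without copying via
// `from_storage_no_op`.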
#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let w_slice = get_cuda_slice::<$wq_t>(&$this.w_q)?;
            let scale_slice = get_cuda_slice::<$scale_t>(&$this.scales)?;
            let zero_slice = get_cuda_slice::<$scale_t>(&$this.zeros)?;

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count()).w()? };
            let out_ptr = *out.device_ptr() as *mut $scale_t;
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice,
                    scale_slice,
                    zero_slice,
                    out_ptr,
                    h as i32,
                    w as i32,
                );
            }

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}

#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
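    /// Returns the packing function for this bit width. The quantized values are stacked
    /// along dim 0 and `8 / bits` consecutive row chunks are shifted and OR-ed into a single
    /// output row (3-bit instead packs 10 chunks into an `i32` word starting at bit 27, after
    /// zero-padding dim 0 to a multiple of 10). For example, at 4 bits the values 3 (top half)
    /// and 5 (bottom half) in the same column pack to `(3 << 4) | 5 = 0x35`.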
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| wq.to_dtype(DType::U8),
            Self::Four => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
            Self::Two => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 4.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;

                a.leftshift(6)?
                    .bitwise_or(&b.leftshift(4)?)?
                    .bitwise_or(&c.leftshift(2)?)?
                    .bitwise_or(&d)
            },
            Self::Three => |wq_in: Tensor| {
                let wq = Tensor::zeros(
                    (
                        (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize,
                        wq_in.dims()[1],
                    ),
                    DType::U32,
                    wq_in.device(),
                )?;
                let wq = wq
                    .slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::U32)?)?
                    .to_dtype(DType::I32)?;
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            },
            Self::One => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 8.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;

                a.leftshift(7)?
                    .bitwise_or(&b.leftshift(6)?)?
                    .bitwise_or(&c.leftshift(5)?)?
                    .bitwise_or(&d.leftshift(4)?)?
                    .bitwise_or(&e.leftshift(3)?)?
                    .bitwise_or(&f.leftshift(2)?)?
                    .bitwise_or(&g.leftshift(1)?)?
                    .bitwise_or(&h)
            },
        }
    }
}

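/// Configuration for HQQ quantization. When built through `QuantMethodConfig::Hqq`,
/// `round_zeros` defaults to `false` and `channel_wise` to `true`.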
#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,
    pub channel_wise: bool,
}

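/// A linear layer whose weight has been quantized with HQQ: `w_q` holds the bit-packed
/// quantized values, `scales` and `zeros` the dequantization parameters, and `w_shape`
/// the original (unpacked) weight shape.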
#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }

    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}

impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding a weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire();
        if imatrix_weight.is_some() {
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            _ => candle_core::bail!("Expected an HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

impl QuantizedSerde for HqqLayer {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
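    /// Serializes to the UQFF layout read back by `deserialize` below (integers little-endian):
    /// UQFF version (u32), ISQ type (u8), has-bias flag (u8), the `w_q`, `scales`, and `zeros`
    /// tensors, the weight shape as a u32 rank followed by u32 dims, then the config as
    /// bits (u8), group size (u32), axis (u8), optimization steps (u32, 0 meaning `None`),
    /// round_zeros (u8), channel_wise (u8), and finally the bias tensor if present.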
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        buffer.push(QuantizedSerdeType::Hqq as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        buffer.extend((w_shape.len() as u32).to_le_bytes());
        for dim in w_shape {
            buffer.extend((*dim as u32).to_le_bytes());
        }

        buffer.push(self.cfg.bits as u8);
        buffer.extend(
            &(<NonZeroUsize as Into<usize>>::into(self.cfg.group_size) as u32).to_le_bytes(),
        );
        buffer.push(self.cfg.axis as u8);
        buffer.extend(&(self.cfg.optimization_steps.unwrap_or(0) as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}

impl HqqLayer {
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("cannot convert hqq bits to isq type")
            }
        }
    }
}
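
#[cfg(test)]
mod tests {
    // Illustrative sanity checks, not an exhaustive test suite. They assume that the
    // `leftshift`/`bitwise_or` custom ops and `HqqLayer::quantize` (from `mod quantize`)
    // accept CPU tensors, and the 8-bit error tolerance below is an assumption rather than
    // a documented guarantee.
    use super::*;

    #[test]
    fn bits_try_from() -> Result<()> {
        assert!(matches!(HqqBits::try_from(4)?, HqqBits::Four));
        assert!(HqqBits::try_from(5).is_err());
        Ok(())
    }

    #[test]
    fn four_bit_packing_packs_two_nibbles_per_byte() -> Result<()> {
        // Row 0 (top half) lands in the high nibble, row 1 (bottom half) in the low nibble.
        let wq = Tensor::new(&[[3u8], [5u8]], &Device::Cpu)?;
        let packed = HqqBits::Four.bitpack_type()(wq)?;
        assert_eq!(packed.to_vec2::<u8>()?, vec![vec![(3u8 << 4) | 5u8]]);
        Ok(())
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn eight_bit_quantize_dequantize_roundtrip() -> Result<()> {
        // Quantize a small weight matrix at 8 bits and check that dequantization
        // approximately recovers it (CPU path only; the CUDA path needs a CUDA device).
        let dev = Device::Cpu;
        let w = Tensor::randn(0f32, 1f32, (64, 64), &dev)?;
        let cfg = HqqConfig {
            bits: HqqBits::Eight,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: None,
            round_zeros: false,
            channel_wise: true,
        };
        let layer = HqqLayer::quantize(&w, &dev, cfg)?;
        let mean_err = (layer.dequantize()? - &w)?
            .abs()?
            .mean_all()?
            .to_scalar::<f32>()?;
        assert!(mean_err < 0.05, "mean abs error too large: {mean_err}");
        Ok(())
    }
}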