mistralrs_quant/hqq/mod.rs

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice, WrapErr},
    from_storage_no_op, CudaStorage, Storage,
};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::{get_cuda_device, get_cuda_slice};

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

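// Defaults for HQQ when it is applied via ISQ, plus the optimizer's default step count.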
pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

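// Shared body for the CUDA dequantization paths below: it grabs raw device pointers to the
// packed weights, scales, and zeros, allocates the output buffer, launches the matching FFI
// kernel (`$bit_thing::dequantize_$postfix`), and wraps the buffer into a Tensor without an
// extra copy via `from_storage_no_op`.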
#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let w_slice = get_cuda_slice::<$wq_t>(&$this.w_q)?;
            let scale_slice = get_cuda_slice::<$scale_t>(&$this.scales)?;
            let zero_slice = get_cuda_slice::<$scale_t>(&$this.zeros)?;

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count()).w()? };
            let out_ptr = *out.device_ptr() as *mut $scale_t;
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice,
                    scale_slice,
                    zero_slice,
                    out_ptr,
                    h as i32,
                    w as i32,
                );
            }

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}

#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
    // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/bitpack.py#L10
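    // Packing scheme: the quantized values are split along dim 0 into equal chunks (2 for
    // 4-bit, 4 for 2-bit, 8 for 1-bit, and 10 for 3-bit after padding the rows to a multiple
    // of 10), which are shifted and OR'd together so several values share one u8 (or i32 for
    // 3-bit) storage element. E.g. for 4-bit:
    //   packed[r][c] = (wq[r][c] << 4) | wq[r + step][c], with step = rows / 2.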
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| wq.to_dtype(DType::U8),
            Self::Four => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
            Self::Two => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 4.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;

                a.leftshift(6)?
                    .bitwise_or(&b.leftshift(4)?)?
                    .bitwise_or(&c.leftshift(2)?)?
                    .bitwise_or(&d)
            },
            Self::Three => |wq_in: Tensor| {
                let wq = Tensor::zeros(
                    (
                        (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize,
                        wq_in.dims()[1],
                    ),
                    DType::U32,
                    wq_in.device(),
                )?;
                let wq = wq
                    .slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::U32)?)?
                    .to_dtype(DType::I32)?;
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            Self::One => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 8.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;

                a.leftshift(7)?
                    .bitwise_or(&b.leftshift(6)?)?
                    .bitwise_or(&c.leftshift(5)?)?
                    .bitwise_or(&d.leftshift(4)?)?
                    .bitwise_or(&e.leftshift(3)?)?
                    .bitwise_or(&f.leftshift(2)?)?
                    .bitwise_or(&g.leftshift(1)?)?
                    .bitwise_or(&h)
            },
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,  // default false
    pub channel_wise: bool, // default true
}

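/// An HQQ-quantized linear layer: packed quantized weights plus per-group scales and zero
/// points, with an optional bias.
///
/// A rough usage sketch (`quantize` is implemented in `quantize.rs`; exact signatures and
/// visibility may differ slightly):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let dev = Device::Cpu;
/// let weight = Tensor::randn(0f32, 1f32, (4096, 4096), &dev)?;
/// let cfg = HqqConfig {
///     bits: HqqBits::Four,
///     group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
///     axis: HqqAxis::Zero,
///     optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
///     round_zeros: false,
///     channel_wise: true,
/// };
/// let layer = HqqLayer::quantize(&weight, &dev, cfg)?;
/// let xs = Tensor::randn(0f32, 1f32, (1, 4096), &dev)?;
/// let ys = layer.forward(&xs)?; // dequantize + matmul, via the `QuantMethod` trait
/// ```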
#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
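    // Non-CUDA path: dequantization goes through the custom candle ops in `hqq_op` (one per
    // bit width); the CUDA path further below dispatches to the FFI kernels instead.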
    /// Dequantize `self` into a tensor with the dtype of `scales`/`zeros`, reshaped to the original weight shape.
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }

    /// Dequantize `self` into a tensor with the dtype of `scales`/`zeros`, reshaped to the original weight shape.
    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            // 8 bits
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            // 4 bits
            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            // 3 bits
            // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/kernels/hqq_aten_cuda.cpp#L42-L45
            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            // 2 bits
            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            // 1 bit
            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        // Dispatch to the unquantized linear path, which uses cuBLASLt (including for the bias) on CUDA, so it is faster here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}

impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        /*
        if self.cfg.force_dequantize {
            self.dequantize_matmul(a)
        } else {
            todo!()
        } */
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire(&device);
        if imatrix_weight.is_some() {
            // TODO just warn?
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            // Some(IsqType::HQQ3) => HqqBits::Three,
            // Some(IsqType::HQQ2) => HqqBits::Two,
            // Some(IsqType::HQQ1) => HqqBits::One,
            _ => candle_core::bail!("Expected a HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (2 for hqq), u8, little endian
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Quantized scale tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Quantized zeroes tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Number of weight (after dequant) shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): Weight (after dequant) shape dims, u32, little endian
// ...
// -----------------------
// Cfg bits, u8, little endian
// -----------------------
// Cfg group size, u32, little endian
// -----------------------
// Cfg axis, u8, little endian
// -----------------------
// Cfg optimization steps, u32, little endian
// -----------------------
// Cfg round_zeros, boolean u8, little endian
// -----------------------
// Cfg channel_wise, boolean u8, little endian
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
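//
// NOTE: `deserialize`, `deserialize_ext_bias`, and `get_isq_type_from_uqff` below all re-read
// this header and must stay in sync with the field order above.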

impl QuantizedSerde for HqqLayer {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for hqq is 2
        buffer.push(QuantizedSerdeType::Hqq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        buffer.extend((w_shape.len() as u32).to_le_bytes());
        for dim in w_shape {
            buffer.extend((*dim as u32).to_le_bytes());
        }

        // Config
        buffer.push(self.cfg.bits as u8);
        buffer.extend(
            &(<NonZeroUsize as Into<usize>>::into(self.cfg.group_size) as u32).to_le_bytes(),
        );
        buffer.push(self.cfg.axis as u8);
        // NOTE: 0 serves as a sentinel for `None`, since 0 optimization steps is never a meaningful value.
        buffer.extend(&(self.cfg.optimization_steps.unwrap_or(0) as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}

impl HqqLayer {
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        // NOTE: keep this in sync with `deserialize` and `deserialize_ext_bias`!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("cannot convert hqq bits to isq type")
            }
        }
    }
}