mistralrs_quant/hqq/mod.rs

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice, WrapErr},
    from_storage_no_op, CudaStorage, Storage,
};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::{get_cuda_device, get_cuda_slice};

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let w_slice = get_cuda_slice::<$wq_t>(&$this.w_q)?;
            let scale_slice = get_cuda_slice::<$scale_t>(&$this.scales)?;
            let zero_slice = get_cuda_slice::<$scale_t>(&$this.zeros)?;

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count()).w()? };
            let out_ptr = *out.device_ptr() as *mut $scale_t;
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice,
                    scale_slice,
                    zero_slice,
                    out_ptr,
                    h as i32,
                    w as i32,
                );
            }

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}
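
// For example, the 8-bit/F32 arm of `dequantize` below invokes this macro as
// `dequant_for_dtype!(self, w = u8, sz = f32, F32, pack = 1, dev, eight_bit, 8bit_u8_kernel_f32)`;
// `paste` then concatenates the FFI call into `eight_bit::dequantize_8bit_u8_kernel_f32(..)`.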

#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
    // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/bitpack.py#L10
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| wq.to_dtype(DType::U8),
            // Pack two 4-bit values per byte: the top half of the rows goes into the
            // high nibble, the bottom half into the low nibble.
            Self::Four => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
            // Pack four 2-bit values per byte: quarters of the rows at shifts 6/4/2/0.
            Self::Two => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 4.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;

                a.leftshift(6)?
                    .bitwise_or(&b.leftshift(4)?)?
                    .bitwise_or(&c.leftshift(2)?)?
                    .bitwise_or(&d)
            },
            // Pack ten 3-bit values per 32-bit word (30 of 32 bits used); the rows
            // are first zero-padded to a multiple of 10.
            Self::Three => |wq_in: Tensor| {
                let wq = Tensor::zeros(
                    (
                        (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize,
                        wq_in.dims()[1],
                    ),
                    DType::U32,
                    wq_in.device(),
                )?;
                let wq = wq
                    .slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::U32)?)?
                    .to_dtype(DType::I32)?;
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            },
            // Pack eight 1-bit values per byte: eighths of the rows at shifts 7..=0.
            Self::One => |wq: Tensor| {
                let wq = wq.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 8.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;

                a.leftshift(7)?
                    .bitwise_or(&b.leftshift(6)?)?
                    .bitwise_or(&c.leftshift(5)?)?
                    .bitwise_or(&d.leftshift(4)?)?
                    .bitwise_or(&e.leftshift(3)?)?
                    .bitwise_or(&f.leftshift(2)?)?
                    .bitwise_or(&g.leftshift(1)?)?
                    .bitwise_or(&h)
            },
        }
    }
}
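
// A minimal sanity sketch of the 4-bit layout above, assuming the CPU kernels
// behind `LeftshiftOp`/`BitWiseOp` accept `U8`: row `i` of the packed output holds
// input row `i` in its high nibble and row `i + h/2` in its low nibble.
#[cfg(test)]
mod bitpack_sketch {
    use super::*;

    #[test]
    fn four_bit_pack_layout() -> Result<()> {
        let dev = Device::Cpu;
        // A (4, 1) matrix holding the 4-bit values 0..=3.
        let wq = Tensor::from_vec(vec![0u8, 1, 2, 3], (4, 1), &dev)?;
        let packed = HqqBits::Four.bitpack_type()(wq)?;
        // Row 0: 0 << 4 | 2 = 0x02; row 1: 1 << 4 | 3 = 0x13.
        assert_eq!(packed.to_vec2::<u8>()?, vec![vec![0x02], vec![0x13]]);
        Ok(())
    }
}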

#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,  // default false
    pub channel_wise: bool, // default true
}
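
// A minimal construction sketch mirroring the ISQ defaults used in `apply_isq`
// below (illustrative values, not the only valid combination):
//
//     let cfg = HqqConfig {
//         bits: HqqBits::Four,
//         group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
//         axis: HqqAxis::Zero,
//         optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
//         round_zeros: false,
//         channel_wise: true,
//     };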

#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
    /// Dequantize `self` into a tensor with the dtype of `scales`/`zeros` and the
    /// original weight shape (`w_shape`).
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }
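
    // Each `Dequant*Bit` op above unpacks its bit width and applies the usual HQQ
    // affine reconstruction elementwise, `w = (w_q - zero) * scale`, broadcasting
    // `scales`/`zeros` per group along axis 0.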

    /// Dequantize `self` into a tensor with the dtype of `scales`/`zeros` and the
    /// original weight shape (`w_shape`).
    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            // 8 bits
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            // 4 bits
            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            // 3 bits
            // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/kernels/hqq_aten_cuda.cpp#L42-L45
            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            // 2 bits
            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            // 1 bit
            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        // Dispatch to the unquantized path: on CUDA it always uses cuBLASLt (with a
        // fused bias epilogue), so it performs better than a hand-rolled matmul here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}
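
// A minimal end-to-end sketch (assumes a CPU device; `HqqLayer::quantize` is the
// constructor from the `quantize` submodule, used in `QuantMethod::new` below):
//
//     let weight = Tensor::randn(0f32, 1f32, (256, 256), &Device::Cpu)?;
//     let cfg = HqqConfig {
//         bits: HqqBits::Eight,
//         group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
//         axis: HqqAxis::Zero,
//         optimization_steps: None,
//         round_zeros: false,
//         channel_wise: true,
//     };
//     let layer = HqqLayer::quantize(&weight, weight.device(), cfg)?;
//     let xs = Tensor::randn(0f32, 1f32, (1, 256), &Device::Cpu)?;
//     let ys = layer.forward(&xs)?; // dequantize + matmul (+ bias, if set)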

impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::Gptq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        /*
        if self.cfg.force_dequantize {
            self.dequantize_matmul(a)
        } else {
            todo!()
        } */
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire();
        if imatrix_weight.is_some() {
            // TODO just warn?
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            // Some(IsqType::HQQ3) => HqqBits::Three,
            // Some(IsqType::HQQ2) => HqqBits::Two,
            // Some(IsqType::HQQ1) => HqqBits::One,
            _ => candle_core::bail!("Expected a HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (2 for hqq), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Scale tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Zeros tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Number of weight (after dequant) shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): weight (after dequant) shape dims, u32, little endian
// ...
// -----------------------
// Cfg bits, u8
// -----------------------
// Cfg group size, u32, little endian
// -----------------------
// Cfg axis, u8
// -----------------------
// Cfg optimization steps, u32, little endian (0 encodes `None`)
// -----------------------
// Cfg round_zeros, boolean u8
// -----------------------
// Cfg channel_wise, boolean u8
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

impl QuantizedSerde for HqqLayer {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for hqq is 2
        buffer.push(QuantizedSerdeType::Hqq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        buffer.extend((w_shape.len() as u32).to_le_bytes());
        for dim in w_shape {
            buffer.extend((*dim as u32).to_le_bytes());
        }

        // Config
        buffer.push(self.cfg.bits as u8);
        buffer.extend(
            &(<NonZeroUsize as Into<usize>>::into(self.cfg.group_size) as u32).to_le_bytes(),
        );
        buffer.push(self.cfg.axis as u8);
        // NOTE: 0 is a safe sentinel for `None` here because a real optimization-step
        // count is never 0.
        buffer.extend(&(self.cfg.optimization_steps.unwrap_or(0) as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire();
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}
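
// A minimal UQFF round-trip sketch over the layout documented above (`layer`,
// `comm`, and `guard` are assumed to come from the surrounding runtime):
//
//     let bytes = layer.serialize()?;
//     let restored = HqqLayer::deserialize(bytes, &Device::Cpu, &comm, guard)?;
//     // `restored` is an `Arc<dyn QuantMethod>` with the same weights and config.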

impl HqqLayer {
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with `deserialize` above!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("cannot convert hqq bits to isq type")
            }
        }
    }
}