mistralrs_quant/hqq/mod.rs
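//! HQQ (Half-Quadratic Quantization) linear layers: bit-packing of quantized
//! weights at 1/2/3/4/8 bits, CUDA/Metal/CPU dequantization paths, and UQFF
//! (de)serialization for `HqqLayer`.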

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice},
    from_storage_no_op, CudaStorage, Storage,
};

#[cfg(feature = "metal")]
use candle_core::{from_storage_no_op, Storage};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::get_cuda_device;

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
mod bitpack_ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

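// Defaults: ISQ (`apply_isq` below) quantizes with 64-element groups and 10
// optimizer steps; `OPTIMIZER_HQQ_DEFAULT_STEPS` is the optimizer's own
// default (see the `optimize` module).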
pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

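// Expands to a CUDA dequantization for one (packed dtype, scale dtype) pair:
// it extracts raw device pointers from the packed `w_q`, `scales`, and `zeros`
// tensors, allocates an output with `pack * h` rows (each packed row holds
// `pack` original rows), launches the matching FFI kernel, and wraps the
// result into a `Tensor` without copying.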
#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let (wq, _) = $this.w_q.storage_and_layout();
            let wq = match &*wq {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("wq must be a cuda tensor"),
            };
            let (w_slice, _w_guard) = crate::utils::slice_ptr(wq.as_cuda_slice::<$wq_t>()?, $this.w_q.layout().start_offset());

            let (scale, _) = $this.scales.storage_and_layout();
            let scale = match &*scale {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("scale must be a cuda tensor"),
            };
            let (scale_slice, _scale_guard) = crate::utils::slice_ptr(scale.as_cuda_slice::<$scale_t>()?, $this.scales.layout().start_offset());

            let (zero, _) = $this.zeros.storage_and_layout();
            let zero = match &*zero {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("zero must be a cuda tensor"),
            };
            let (zero_slice, _zero_guard) = crate::utils::slice_ptr(zero.as_cuda_slice::<$scale_t>()?, $this.zeros.layout().start_offset());

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count())? };
            let (out_ptr, out_guard) = out.device_ptr(out.stream());
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice as *const $wq_t,
                    scale_slice as *const $scale_t,
                    zero_slice as *const $scale_t,
                    out_ptr as *mut $scale_t,
                    h as i32,
                    w as i32,
                );
            }
            drop(out_guard);

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}

#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
    // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/bitpack.py#L10
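    // Packing always operates along axis 0: `pack` input rows are folded into
    // one output row (1 for 8-bit, 2 for 4-bit, 4 for 2-bit, 8 for 1-bit, and
    // 10 rows per 32-bit word for 3-bit). Each arm returns a closure; CUDA
    // (and, for the 8/4-bit cases, Metal) devices dispatch to dedicated
    // kernels, while other devices use the shift-and-or tensor ops below.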
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 8-bit (which is essentially a copy)
                    let dev = get_cuda_device(&wq)?;
                    let wq = wq.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_shape = wq.shape().clone();
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_8bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            output_shape.elem_count(),
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                #[cfg(feature = "metal")]
                if device.is_metal() {
                    use candle_core::MetalStorage;

                    let dev = device.as_metal_device()?;
                    let command_buffer = dev.command_buffer()?;
                    command_buffer.set_label("hqq_pack_8bit");

                    let (wq_storage, _wq_layout) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Metal(s) => s,
                        _ => candle_core::bail!("Expected Metal storage"),
                    };

                    let output_shape = wq.shape().clone();
                    let output = dev.new_buffer(
                        output_shape.elem_count(),
                        DType::U8,
                        "hqq_pack_8bit_output",
                    )?;

                    crate::metal_kernels::call_hqq_pack_8bit(
                        dev.device(),
                        &command_buffer,
                        &crate::metal_kernels::Kernels::new(),
                        wq_storage.buffer(),
                        &output,
                        output_shape.elem_count(),
                    )
                    .map_err(candle_core::Error::wrap)?;

                    let storage = MetalStorage::new(
                        output,
                        dev.clone(),
                        output_shape.elem_count(),
                        DType::U8,
                    );
                    let storage = Storage::Metal(storage);

                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                wq.to_dtype(DType::U8)
            },
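            // 4-bit layout: with `s = h / 2`, output row `i` is
            // `in[i] << 4 | in[i + s]`, i.e. row `i` in the high nibble and
            // row `i + s` in the low nibble.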
            Self::Four => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 4-bit packing
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 2;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_4bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                #[cfg(feature = "metal")]
                if device.is_metal() {
                    use candle_core::MetalStorage;

                    let dev = device.as_metal_device()?;
                    let command_buffer = dev.command_buffer()?;
                    command_buffer.set_label("hqq_pack_4bit");

                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _wq_layout) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Metal(s) => s,
                        _ => candle_core::bail!("Expected Metal storage"),
                    };

                    let output_height = wq.dims()[0] / 2;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = dev.new_buffer(
                        output_shape.elem_count(),
                        DType::U8,
                        "hqq_pack_4bit_output",
                    )?;

                    crate::metal_kernels::call_hqq_pack_4bit(
                        dev.device(),
                        &command_buffer,
                        &crate::metal_kernels::Kernels::new(),
                        wq_storage.buffer(),
                        &output,
                        wq.dims()[0],
                        wq.dims()[1],
                    )
                    .map_err(candle_core::Error::wrap)?;

                    let storage = MetalStorage::new(
                        output,
                        dev.clone(),
                        output_shape.elem_count(),
                        DType::U8,
                    );
                    let storage = Storage::Metal(storage);

                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback
                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
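            // 2-bit layout: with `s = h / 4`, output row `i` packs four rows
            // MSB-first: `in[i] << 6 | in[i+s] << 4 | in[i+2s] << 2 | in[i+3s]`.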
            Self::Two => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 2-bit packing
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 4;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_2bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback
                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 4.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;

                a.leftshift(6)?
                    .bitwise_or(&b.leftshift(4)?)?
                    .bitwise_or(&c.leftshift(2)?)?
                    .bitwise_or(&d)
            },
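            // 3-bit layout: rows are zero-padded to a multiple of 10, then ten
            // 3-bit fields are packed into each 32-bit word at bit offsets
            // 27, 24, ..., 3, 0; the top two bits of every word stay unused.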
            Self::Three => |wq_in: Tensor| -> Result<Tensor> {
                let device = wq_in.device();

                // Pad input to multiple of 10
                let padded_height = (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize;
                let wq = Tensor::zeros((padded_height, wq_in.dims()[1]), DType::U32, device)?;
                let wq = wq.slice_assign(
                    &[0..wq_in.dims()[0], 0..wq.dims()[1]],
                    &wq_in.to_dtype(DType::U32)?,
                )?;

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for efficient 3-bit packing
                    let dev = get_cuda_device(&wq)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = padded_height / 10;
                    let output_shape = Shape::from_dims(&[output_height, wq_in.dims()[1]]);
                    let output = unsafe { dev.alloc::<i32>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u32>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_3bit_kernel(
                            input_ptr as *const u32,
                            output_ptr as *mut i32,
                            padded_height,
                            wq_in.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback implementation
                let wq = if wq.device().is_metal() {
                    // Metal doesn't support direct U32 to I32 conversion, use CPU as intermediate
                    let cpu_wq = wq.to_device(&Device::Cpu)?;
                    cpu_wq.to_dtype(DType::I32)?.to_device(wq.device())?
                } else {
                    wq.to_dtype(DType::I32)?
                };
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            },
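            // 1-bit layout: with `s = h / 8`, output row `i` packs eight rows
            // MSB-first into one byte: `in[i] << 7 | in[i+s] << 6 | ... | in[i+7s]`.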
            Self::One => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 1-bit packing
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 8;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_1bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback
                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 8.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;

                a.leftshift(7)?
                    .bitwise_or(&b.leftshift(6)?)?
                    .bitwise_or(&c.leftshift(5)?)?
                    .bitwise_or(&d.leftshift(4)?)?
                    .bitwise_or(&e.leftshift(3)?)?
                    .bitwise_or(&f.leftshift(2)?)?
                    .bitwise_or(&g.leftshift(1)?)?
                    .bitwise_or(&h)
            },
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,  // default false
    pub channel_wise: bool, // default true
}

#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
    /// Dequantize `self` into a tensor of shape `w_shape`, with the dtype of `scales`/`zeros`.
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }

    /// Dequantize `self` into a tensor of shape `w_shape`, with the dtype of `scales`/`zeros`.
    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            // 8 bits
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            // 4 bits
            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            // 3 bits
            // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/kernels/hqq_aten_cuda.cpp#L42-L45
            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            // 2 bits
            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            // 1 bit
            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        // Dispatch through UnquantLinear: on CUDA it can use cuBLASLt (with fused bias),
        // which is faster than a plain matmul here.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}

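// Typical flow (a sketch; `HqqLayer::quantize` lives in the `quantize` module):
//
//     let cfg = HqqConfig {
//         bits: HqqBits::Four,
//         group_size: NonZeroUsize::new(64).unwrap(),
//         axis: HqqAxis::Zero,
//         optimization_steps: None,
//         round_zeros: false,
//         channel_wise: true,
//     };
//     let hqq = HqqLayer::quantize(&weight, weight.device(), cfg)?;
//     let y = hqq.forward(&x)?;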
impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        /*
        if self.cfg.force_dequantize {
            self.dequantize_matmul(a)
        } else {
            todo!()
        } */
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire(&device);
        if imatrix_weight.is_some() {
            // TODO just warn?
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            // Some(IsqType::HQQ3) => HqqBits::Three,
            // Some(IsqType::HQQ2) => HqqBits::Two,
            // Some(IsqType::HQQ1) => HqqBits::One,
            _ => candle_core::bail!("Expected a HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (2 for hqq), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Quantized scale tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Quantized zeros tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Number of weight (after dequant) shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): weight (after dequant) shape dims, u32 each, little endian
// ...
// -----------------------
// Cfg bits, u8
// -----------------------
// Cfg group size, u32, little endian
// -----------------------
// Cfg axis, u8
// -----------------------
// Cfg optimization steps (0 encodes None), u32, little endian
// -----------------------
// Cfg round_zeros, boolean u8
// -----------------------
// Cfg channel_wise, boolean u8
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------

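// `serialize_with_bias`, `deserialize`, `deserialize_ext_bias`, and
// `get_isq_type_from_uqff` below all walk this exact layout and must be kept
// in sync.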
impl QuantizedSerde for HqqLayer {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for hqq is 2
        buffer.push(QuantizedSerdeType::Hqq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        let shape_len = w_shape.len();
        if shape_len > u32::MAX as usize {
            candle_core::bail!(
                "Weight tensor has too many dimensions for UQFF format: {} exceeds u32::MAX",
                shape_len
            );
        }
        buffer.extend((shape_len as u32).to_le_bytes());
        for dim in w_shape {
            if *dim > u32::MAX as usize {
                candle_core::bail!(
                    "Weight tensor dimension too large for UQFF format: {} exceeds u32::MAX",
                    dim
                );
            }
            buffer.extend((*dim as u32).to_le_bytes());
        }

        // Config
        buffer.push(self.cfg.bits as u8);
        let group_size = <NonZeroUsize as Into<usize>>::into(self.cfg.group_size);
        if group_size > u32::MAX as usize {
            candle_core::bail!(
                "HQQ group size too large for UQFF format: {} exceeds u32::MAX",
                group_size
            );
        }
        buffer.extend(&(group_size as u32).to_le_bytes());
        buffer.push(self.cfg.axis as u8);
        // NOTE: using 0 as a sentinel for None. This means legitimate 0 values cannot be distinguished from None.
        // This is acceptable because 0 optimization steps would be functionally equivalent to None.
        let opt_steps = self.cfg.optimization_steps.unwrap_or(0);
        if opt_steps > u32::MAX as usize {
            candle_core::bail!(
                "HQQ optimization steps too large for UQFF format: {} exceeds u32::MAX",
                opt_steps
            );
        }
        buffer.extend(&(opt_steps as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}

impl HqqLayer {
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with `deserialize` above!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("cannot convert hqq bits to isq type")
            }
        }
    }
}