mistralrs_quant/hqq/mod.rs

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Shape, Tensor};

#[cfg(feature = "cuda")]
use candle_core::{
    cuda::{cudarc::driver::DevicePtr, CudaStorageSlice},
    from_storage_no_op, CudaStorage, Storage,
};

#[cfg(feature = "metal")]
use candle_core::{from_storage_no_op, Storage};

use candle_nn::Linear;
#[cfg(feature = "cuda")]
use half::{bf16, f16};
use std::{
    borrow::Cow,
    io::Cursor,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc},
};

use crate::{
    utils::{
        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
        BitWiseOp, LeftshiftOp, UQFF_VERSION,
    },
    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    UnquantLinear,
};

#[cfg(feature = "cuda")]
use crate::utils::get_cuda_device;

#[cfg(feature = "cuda")]
use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit};

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
mod bitpack_ffi;

#[cfg(not(feature = "cuda"))]
mod hqq_op;

mod optimize;
mod quantize;

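// Defaults used when HQQ is selected as an ISQ (in-situ quantization) target; see
// `apply_isq` below, which quantizes with 64-element groups along axis 0 and
// `Some(10)` optimization steps.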
pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64;
pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option<usize> = Some(10);
pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20;

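// `dequant_for_dtype!` wraps a call into one of the CUDA dequantization kernels in `ffi`.
// `w` is the packed storage element type, `sz` the scale/zero/output element type, and
// `pack` the number of weight values per packed element, so a packed `[h, w]` input
// dequantizes to `[pack * h, w]`. For example (this exact call appears in `dequantize`
// below):
//
//     dequant_for_dtype!(self, w = u8, sz = f32, F32, pack = 1, dev, eight_bit, 8bit_u8_kernel_f32)
//
// expands to the storage extraction, an output allocation, and a call to
// `eight_bit::dequantize_8bit_u8_kernel_f32`.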
#[cfg(feature = "cuda")]
macro_rules! dequant_for_dtype {
    ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{
        paste::paste! {
            let (wq, _) = $this.w_q.storage_and_layout();
            let wq = match &*wq {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("wq must be a cuda tensor"),
            };
            let (w_slice, _w_guard) = crate::utils::slice_ptr(wq.as_cuda_slice::<$wq_t>()?, $this.w_q.layout().start_offset());

            let (scale, _) = $this.scales.storage_and_layout();
            let scale = match &*scale {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("scale must be a cuda tensor"),
            };
            let (scale_slice, _scale_guard) = crate::utils::slice_ptr(scale.as_cuda_slice::<$scale_t>()?, $this.scales.layout().start_offset());

            let (zero, _) = $this.zeros.storage_and_layout();
            let zero = match &*zero {
                candle_core::Storage::Cuda(s) => s,
                _ => candle_core::bail!("zero must be a cuda tensor"),
            };
            let (zero_slice, _zero_guard) = crate::utils::slice_ptr(zero.as_cuda_slice::<$scale_t>()?, $this.zeros.layout().start_offset());

            let (h, w) = $this.w_q.dims2()?;
            let num_packed_elems = $pack;
            let out_shape = Shape::from_dims(&[num_packed_elems * h, w]);

            let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count())? };
            let (out_ptr, out_guard) = out.device_ptr(out.stream());
            unsafe {
                $bit_thing::[< dequantize_ $postfix >](
                    w_slice as *const $wq_t,
                    scale_slice as *const $scale_t,
                    zero_slice as *const $scale_t,
                    out_ptr as *mut $scale_t,
                    h as i32,
                    w as i32,
                );
            }
            drop(out_guard);

            let storage = CudaStorage {
                slice: CudaStorageSlice::$dtype(out),
                device: $dev.clone(),
            };
            let storage = Storage::Cuda(storage);

            from_storage_no_op(storage, out_shape, false)
        }
    }};
}

#[derive(Debug, Clone, Copy)]
pub enum HqqAxis {
    Zero = 0,
    One = 1,
}

impl TryFrom<usize> for HqqAxis {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Zero),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ axis {other}"),
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum HqqBits {
    Eight = 8,
    Four = 4,
    Three = 3,
    Two = 2,
    One = 1,
}

impl TryFrom<usize> for HqqBits {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            8 => Ok(Self::Eight),
            4 => Ok(Self::Four),
            3 => Ok(Self::Three),
            2 => Ok(Self::Two),
            1 => Ok(Self::One),
            other => candle_core::bail!("Unexpected value for HQQ bits {other}"),
        }
    }
}

impl HqqBits {
    // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/bitpack.py#L10
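    // Packing always operates along axis 0 of the quantized matrix: 8-bit stores one
    // value per `u8`, 4-bit packs two rows per `u8`, 2-bit four, and 1-bit eight;
    // 3-bit packs ten rows into each `i32` after padding the row count to a multiple of 10.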
    pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result<Tensor> {
        match self {
            Self::Eight => |wq: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 8-bit (which is essentially a copy)
                    let dev = get_cuda_device(&wq)?;
                    let wq = wq.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_shape = wq.shape().clone();
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_8bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            output_shape.elem_count(),
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                #[cfg(feature = "metal")]
                if device.is_metal() {
                    use candle_core::MetalStorage;

                    let dev = device.as_metal_device()?;
                    let command_buffer = dev.command_buffer()?;
                    command_buffer.set_label("hqq_pack_8bit");

                    let (wq_storage, _wq_layout) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Metal(s) => s,
                        _ => candle_core::bail!("Expected Metal storage"),
                    };

                    let output_shape = wq.shape().clone();
                    let output = dev.new_buffer(
                        output_shape.elem_count(),
                        DType::U8,
                        "hqq_pack_8bit_output",
                    )?;

                    crate::metal_kernels::call_hqq_pack_8bit(
                        dev.device(),
                        &command_buffer,
                        &crate::metal_kernels::Kernels::new(),
                        wq_storage.buffer(),
                        &output,
                        output_shape.elem_count(),
                    )
                    .map_err(candle_core::Error::wrap)?;

                    let storage = MetalStorage::new(
                        output,
                        dev.clone(),
                        output_shape.elem_count(),
                        DType::U8,
                    );
                    let storage = Storage::Metal(storage);

                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                wq.to_dtype(DType::U8)
            },
            Self::Four => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 4-bit packing
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 2;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_4bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                #[cfg(feature = "metal")]
                if device.is_metal() {
                    use candle_core::MetalStorage;

                    let dev = device.as_metal_device()?;
                    let command_buffer = dev.command_buffer()?;
                    command_buffer.set_label("hqq_pack_4bit");

                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _wq_layout) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Metal(s) => s,
                        _ => candle_core::bail!("Expected Metal storage"),
                    };

                    let output_height = wq.dims()[0] / 2;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = dev.new_buffer(
                        output_shape.elem_count(),
                        DType::U8,
                        "hqq_pack_4bit_output",
                    )?;

                    crate::metal_kernels::call_hqq_pack_4bit(
                        dev.device(),
                        &command_buffer,
                        &crate::metal_kernels::Kernels::new(),
                        wq_storage.buffer(),
                        &output,
                        wq.dims()[0],
                        wq.dims()[1],
                    )
                    .map_err(candle_core::Error::wrap)?;

                    let storage = MetalStorage::new(
                        output,
                        dev.clone(),
                        output_shape.elem_count(),
                        DType::U8,
                    );
                    let storage = Storage::Metal(storage);

                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback
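                // Two rows per output byte: out[i][j] = (wq[i][j] << 4) | wq[i + step][j],
                // e.g. nibbles 0xA and 0xB pack to the byte 0xAB.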
                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 2.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                a.leftshift(4)?.bitwise_or(&b)
            },
            Self::Two => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 2-bit packing
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 4;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_2bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback
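                // Four rows per output byte: out = (a << 6) | (b << 4) | (c << 2) | d,
                // where a..d are row blocks at offsets 0, step, 2*step, 3*step.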
                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 4.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;

                a.leftshift(6)?
                    .bitwise_or(&b.leftshift(4)?)?
                    .bitwise_or(&c.leftshift(2)?)?
                    .bitwise_or(&d)
            },
            Self::Three => |wq_in: Tensor| -> Result<Tensor> {
                let device = wq_in.device();

                // Pad input to multiple of 10
                let padded_height = (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize;
                let wq = Tensor::zeros((padded_height, wq_in.dims()[1]), DType::U32, device)?;
                let wq = wq.slice_assign(
                    &[0..wq_in.dims()[0], 0..wq.dims()[1]],
                    &wq_in.to_dtype(DType::U32)?,
                )?;

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for efficient 3-bit packing
                    let dev = get_cuda_device(&wq)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = padded_height / 10;
                    let output_shape = Shape::from_dims(&[output_height, wq_in.dims()[1]]);
                    let output = unsafe { dev.alloc::<i32>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u32>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_3bit_kernel(
                            input_ptr as *const u32,
                            output_ptr as *mut i32,
                            padded_height,
                            wq_in.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback implementation
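                // Ten 3-bit fields per i32 word, at shifts 27, 24, ..., 3, 0; the top
                // two bits of each word stay unused.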
                let wq = if wq.device().is_metal() {
                    // Metal doesn't support direct U32 to I32 conversion, use CPU as intermediate
                    let cpu_wq = wq.to_device(&Device::Cpu)?;
                    cpu_wq.to_dtype(DType::I32)?.to_device(wq.device())?
                } else {
                    wq.to_dtype(DType::I32)?
                };
                let step = (wq.dims()[0] as f64 / 10.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;
                let i = wq.narrow(0, step * 8, step)?;
                let j = wq.narrow(0, step * 9, step)?;

                a.leftshift(27)?
                    .bitwise_or(&b.leftshift(24)?)?
                    .bitwise_or(&c.leftshift(21)?)?
                    .bitwise_or(&d.leftshift(18)?)?
                    .bitwise_or(&e.leftshift(15)?)?
                    .bitwise_or(&f.leftshift(12)?)?
                    .bitwise_or(&g.leftshift(9)?)?
                    .bitwise_or(&h.leftshift(6)?)?
                    .bitwise_or(&i.leftshift(3)?)?
                    .bitwise_or(&j)
            },
            Self::One => |wq_in: Tensor| -> Result<Tensor> {
                #[allow(unused_variables)]
                let device = wq_in.device();

                #[cfg(feature = "cuda")]
                if device.is_cuda() {
                    // Use CUDA kernel for 1-bit packing
                    let dev = get_cuda_device(&wq_in)?;
                    let wq = wq_in.to_dtype(DType::U8)?;
                    let (wq_storage, _) = wq.storage_and_layout();
                    let wq_storage = match &*wq_storage {
                        Storage::Cuda(s) => s,
                        _ => candle_core::bail!("Expected CUDA storage"),
                    };

                    let output_height = wq.dims()[0] / 8;
                    let output_shape = Shape::from_dims(&[output_height, wq.dims()[1]]);
                    let output = unsafe { dev.alloc::<u8>(output_shape.elem_count())? };

                    unsafe {
                        let (output_ptr, output_guard) = output.device_ptr(output.stream());
                        let (input_ptr, _input_guard) = crate::utils::slice_ptr(
                            wq_storage.as_cuda_slice::<u8>()?,
                            wq.layout().start_offset(),
                        );

                        bitpack_ffi::launch_pack_1bit_kernel(
                            input_ptr as *const u8,
                            output_ptr as *mut u8,
                            wq.dims()[0],
                            wq.dims()[1],
                            dev.cuda_stream().cu_stream(),
                        );
                        drop(output_guard);
                    }

                    let storage = CudaStorage::wrap_cuda_slice(output, dev.clone());
                    let storage = Storage::Cuda(storage);
                    return Ok(from_storage_no_op(storage, output_shape, false));
                }

                // CPU fallback
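                // Eight rows per output byte: row block k lands in bit (7 - k), so
                // out = (a << 7) | (b << 6) | ... | (g << 1) | h.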
                let wq = wq_in.to_dtype(DType::U8)?;
                let step = (wq.dims()[0] as f64 / 8.) as usize;

                let a = wq.narrow(0, 0, step)?;
                let b = wq.narrow(0, step, step)?;
                let c = wq.narrow(0, step * 2, step)?;
                let d = wq.narrow(0, step * 3, step)?;
                let e = wq.narrow(0, step * 4, step)?;
                let f = wq.narrow(0, step * 5, step)?;
                let g = wq.narrow(0, step * 6, step)?;
                let h = wq.narrow(0, step * 7, step)?;

                a.leftshift(7)?
                    .bitwise_or(&b.leftshift(6)?)?
                    .bitwise_or(&c.leftshift(5)?)?
                    .bitwise_or(&d.leftshift(4)?)?
                    .bitwise_or(&e.leftshift(3)?)?
                    .bitwise_or(&f.leftshift(2)?)?
                    .bitwise_or(&g.leftshift(1)?)?
                    .bitwise_or(&h)
            },
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct HqqConfig {
    pub bits: HqqBits,
    pub group_size: NonZeroUsize,
    pub axis: HqqAxis,
    pub optimization_steps: Option<usize>,
    pub round_zeros: bool,  // default false
    pub channel_wise: bool, // default true
}

#[derive(Debug)]
pub struct HqqLayer {
    pub(crate) w_q: Tensor,
    pub(crate) zeros: Tensor,
    pub(crate) scales: Tensor,
    pub(crate) bias: Option<Tensor>,
    pub(crate) w_shape: Shape,
    pub(crate) cfg: HqqConfig,
}

impl HqqLayer {
    /// Dequantize `self` into a tensor of shape `w_shape`, with the dtype of `scales`/`zeros`.
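    /// Each value is reconstructed as `scale * (w_q - zero)`, the standard HQQ mapping.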
    #[cfg(not(feature = "cuda"))]
    fn dequantize(&self) -> Result<Tensor> {
        use crate::hqq::hqq_op::{Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit};

        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CPU HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let (h, w) = self.w_q.dims2()?;

        match self.cfg.bits as usize {
            8 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })?
                .reshape(&self.w_shape),
            4 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })?
                .reshape(&self.w_shape),
            3 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })?
                .reshape(&self.w_shape),
            2 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })?
                .reshape(&self.w_shape),
            1 => self
                .w_q
                .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })?
                .reshape(&self.w_shape),
            b => candle_core::bail!("Unreachable bits {b}"),
        }
    }

    /// Dequantize `self` into a tensor of shape `w_shape`, with the dtype of `scales`/`zeros`.
    #[cfg(feature = "cuda")]
    fn dequantize(&self) -> Result<Tensor> {
        match (self.scales.dtype(), self.zeros.dtype()) {
            (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (),
            (a, b) => {
                candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).")
            }
        }
        if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous())
        {
            candle_core::bail!("All tensors must be contiguous!");
        }
        if self.cfg.axis as usize != 0 {
            candle_core::bail!(
                "CUDA HQQ dequantization requires axis == 0, got {}.",
                self.cfg.axis as usize
            );
        }
        let dev = get_cuda_device(&self.w_q)?;

        let inner = match (self.cfg.bits as usize, self.scales.dtype()) {
            // 8 bits
            (8, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f32
                )
            }
            (8, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_f16
                )
            }
            (8, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 1,
                    dev,
                    eight_bit,
                    8bit_u8_kernel_bf16
                )
            }

            // 4 bits
            (4, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f32
                )
            }
            (4, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_f16
                )
            }
            (4, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 2,
                    dev,
                    four_bit,
                    4bit_u8_kernel_bf16
                )
            }

            // 3 bits
            // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/kernels/hqq_aten_cuda.cpp#L42-L45
            (3, DType::F32) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f32,
                    F32,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f32
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::F16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = f16,
                    F16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_f16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }
            (3, DType::BF16) => {
                let res = dequant_for_dtype!(
                    self,
                    w = i32,
                    sz = bf16,
                    BF16,
                    pack = 10,
                    dev,
                    three_bit,
                    3bit_32_kernel_bf16
                );
                res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())?
            }

            // 2 bits
            (2, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f32
                )
            }
            (2, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_f16
                )
            }
            (2, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 4,
                    dev,
                    two_bit,
                    2bit_u8_kernel_bf16
                )
            }

            // 1 bit
            (1, DType::F32) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f32,
                    F32,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f32
                )
            }
            (1, DType::F16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = f16,
                    F16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_f16
                )
            }
            (1, DType::BF16) => {
                dequant_for_dtype!(
                    self,
                    w = u8,
                    sz = bf16,
                    BF16,
                    pack = 8,
                    dev,
                    one_bit,
                    1bit_u8_kernel_bf16
                )
            }
            (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"),
        };
        inner.reshape(&self.w_shape)
    }

    fn dequantize_matmul(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize()?;
        // Dispatch through the unquantized path: on CUDA this uses cuBLASLt (fusing the
        // bias when present), which is faster than a hand-rolled matmul + bias add.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            w,
            self.bias.clone(),
        )))?;
        unquant.forward(xs)
    }

    pub fn with_bias(mut self, bias: Tensor) -> Self {
        self.bias = Some(bias);
        self
    }
}
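
// A minimal usage sketch (`weight` and `x` are hypothetical, pre-existing tensors;
// `None` fields fall back to the defaults noted on `HqqConfig`):
//
//     let layer = <HqqLayer as QuantMethod>::new(QuantMethodConfig::Hqq {
//         tensor: weight,
//         bits: HqqBits::Four,
//         group_size: std::num::NonZeroUsize::new(64).unwrap(),
//         axis: HqqAxis::Zero,
//         optimization_steps: None,
//         round_zeros: None,
//         channel_wise: None,
//         bias: None,
//     })?;
//     let y = layer.forward(&x)?;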

impl QuantMethod for HqqLayer {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::Unquantized(_)
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::Afq { .. } => {
                unreachable!()
            }
            QuantMethodConfig::Hqq {
                tensor,
                bits,
                group_size,
                axis,
                optimization_steps,
                round_zeros,
                channel_wise,
                bias,
            } => {
                let cfg = HqqConfig {
                    bits,
                    group_size,
                    axis,
                    optimization_steps,
                    round_zeros: round_zeros.unwrap_or(false),
                    channel_wise: channel_wise.unwrap_or(true),
                };

                let this = Self::quantize(&tensor, tensor.device(), cfg)?;
                if let Some(bias) = bias {
                    Ok(this.with_bias(bias))
                } else {
                    Ok(this)
                }
            }
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.dequantize()
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        /*
        if self.cfg.force_dequantize {
            self.dequantize_matmul(a)
        } else {
            todo!()
        } */
        self.dequantize_matmul(a)
    }

    fn quantized_act_type(&self) -> Option<DType> {
        Some(self.scales.dtype())
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("HQQ quantization does not support adding weight delta.")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        (self.scales.dtype(), self.scales.device().clone())
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let _acquired_quantize_guard = guard.acquire(&device);
        if imatrix_weight.is_some() {
            // TODO just warn?
            candle_core::bail!("HQQ does not support imatrix.");
        }

        n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let bits = match dtype {
            Some(IsqType::HQQ8) => HqqBits::Eight,
            Some(IsqType::HQQ4) => HqqBits::Four,
            // Some(IsqType::HQQ3) => HqqBits::Three,
            // Some(IsqType::HQQ2) => HqqBits::Two,
            // Some(IsqType::HQQ1) => HqqBits::One,
            _ => candle_core::bail!("Expected a HQQ ISQ type."),
        };
        let cfg = HqqConfig {
            bits,
            group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
            axis: HqqAxis::Zero,
            optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
            round_zeros: false,
            channel_wise: true,
        };
        let dequant = self.dequantize()?;
        let res = Self::quantize(&dequant, &device, cfg)?;
        if let Some(ref bias) = self.bias {
            let bias = bias
                .to_device(&device)?
                .to_dtype(res.dtype_and_device().0)?;
            Ok(Arc::new(res.with_bias(bias)))
        } else {
            Ok(Arc::new(res))
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (2 for hqq), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Quantized weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Quantized scale tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Quantized zeroes tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Weight (after dequant) shape dim count, u32, little endian
// -----------------------
// ...
// Array (in original order): Weight (after dequant) shape dims, u32, little endian
// ...
// -----------------------
// Cfg bits, u8
// -----------------------
// Cfg group size, u32, little endian
// -----------------------
// Cfg axis, u8
// -----------------------
// Cfg optimization steps (0 encodes None), u32, little endian
// -----------------------
// Cfg round_zeros, boolean u8
// -----------------------
// Cfg channel_wise, boolean u8
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
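//
// As a worked example, the fixed prefix spans six bytes: offsets 0..=3 hold
// `UQFF_VERSION` as a little-endian u32, offset 4 the ISQ type tag (2 for HQQ),
// and offset 5 the has-bias flag.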

impl QuantizedSerde for HqqLayer {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "hqq"
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for hqq is 2
        buffer.push(QuantizedSerdeType::Hqq as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.w_q)?;
        serialize_tensor(&mut buffer, &self.scales)?;
        serialize_tensor(&mut buffer, &self.zeros)?;

        let w_shape = self.w_shape.dims();
        let shape_len = w_shape.len();
        if shape_len > u32::MAX as usize {
            candle_core::bail!(
                "Weight tensor has too many dimensions for UQFF format: {} exceeds u32::MAX",
                shape_len
            );
        }
        buffer.extend((shape_len as u32).to_le_bytes());
        for dim in w_shape {
            if *dim > u32::MAX as usize {
                candle_core::bail!(
                    "Weight tensor dimension too large for UQFF format: {} exceeds u32::MAX",
                    dim
                );
            }
            buffer.extend((*dim as u32).to_le_bytes());
        }

        // Config
        buffer.push(self.cfg.bits as u8);
        let group_size = <NonZeroUsize as Into<usize>>::into(self.cfg.group_size);
        if group_size > u32::MAX as usize {
            candle_core::bail!(
                "HQQ group size too large for UQFF format: {} exceeds u32::MAX",
                group_size
            );
        }
        buffer.extend(&(group_size as u32).to_le_bytes());
        buffer.push(self.cfg.axis as u8);
        // NOTE: using 0 as a sentinel for None. This means legitimate 0 values cannot be distinguished from None.
        // This is acceptable because 0 optimization steps would be functionally equivalent to None.
        let opt_steps = self.cfg.optimization_steps.unwrap_or(0);
        if opt_steps > u32::MAX as usize {
            candle_core::bail!(
                "HQQ optimization steps too large for UQFF format: {} exceeds u32::MAX",
                opt_steps
            );
        }
        buffer.extend(&(opt_steps as u32).to_le_bytes());
        buffer.push(self.cfg.round_zeros as u8);
        buffer.push(self.cfg.channel_wise as u8);

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            w_q,
            zeros,
            scales,
            bias: b,
            w_shape,
            cfg,
        }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w_q = deserialize_tensor(&mut buffer, device)?;
        let scales = deserialize_tensor(&mut buffer, device)?;
        let zeros = deserialize_tensor(&mut buffer, device)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let w_shape = Shape::from_dims(&dims);

        // TODO: keep this in sync with get_isq_type_from_uqff!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;
        let group_size = NonZeroUsize::try_from(buffer.read_u32::<LittleEndian>()? as usize)?;
        let axis = HqqAxis::try_from(buffer.read_u8()? as usize)?;
        let optimization_steps = match buffer.read_u32::<LittleEndian>()? as usize {
            0 => None,
            other => Some(other),
        };
        let round_zeros = buffer.read_u8()? != 0;
        let channel_wise = buffer.read_u8()? != 0;

        let cfg = HqqConfig {
            bits,
            group_size,
            axis,
            optimization_steps,
            round_zeros,
            channel_wise,
        };

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w_q,
                zeros,
                scales,
                bias: None,
                w_shape,
                cfg,
            }),
            b,
        ))
    }
}

impl HqqLayer {
    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Hqq as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Hqq as usize
            );
        }

        let _has_bias = buffer.read_u8()? != 0;

        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;
        fake_deserialize_tensor(&mut buffer)?;

        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

        let mut dims = Vec::with_capacity(n_dims);
        for _ in 0..n_dims {
            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
        }
        let _w_shape = Shape::from_dims(&dims);

        // NOTE: keep this in sync with `deserialize`!
        let bits = HqqBits::try_from(buffer.read_u8()? as usize)?;

        match bits {
            HqqBits::Eight => Ok(IsqType::HQQ8),
            HqqBits::Four => Ok(IsqType::HQQ4),
            HqqBits::One | HqqBits::Two | HqqBits::Three => {
                candle_core::bail!("cannot convert hqq bits to isq type")
            }
        }
    }
}