mistralrs_quant/hqq/
hqq_op.rs

1#[cfg(feature = "metal")]
2use candle_core::{backend::BackendStorage, DType};
3use candle_core::{CpuStorage, CustomOp3, Layout, Result, Shape, WithDType};
4
/*
 8 bit
*/
/// HQQ dequantization op for 8-bit weights.
///
/// The quantized weight is a `[h, w]` u8 tensor holding one code per element; the op
/// produces a `[h, w]` tensor in the dtype of the scales/zeros.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant8Bit {
    /// Number of rows of the quantized weight.
    pub(crate) h: usize,
    /// Number of columns of the quantized weight; scales and zeros are indexed per column.
    pub(crate) w: usize,
}
12
13impl Dequant8Bit {
14    fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
15        let mut out = vec![T::default(); w.len()];
16        for (i, w) in w.iter().enumerate() {
17            let j = i % self.w;
18            out[i] = (T::from_f64(*w as f64) - z[j]) * s[j];
19        }
20        out
21    }
22}
23
24impl CustomOp3 for Dequant8Bit {
25    fn name(&self) -> &'static str {
26        "dequant-hqq-8bit"
27    }
28    fn cpu_fwd(
29        &self,
30        w: &CpuStorage,
31        l_w: &Layout,
32        s: &CpuStorage,
33        l_s: &Layout,
34        z: &CpuStorage,
35        l_z: &Layout,
36    ) -> Result<(CpuStorage, Shape)> {
37        let CpuStorage::U8(w_slice) = w else {
38            candle_core::bail!("Weight must be u8, HQQ dequant 8-bit");
39        };
40        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
41            candle_core::bail!("All inputs must be contiguous");
42        }
43        match (s, z) {
44            (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
45                CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
46                Shape::from_dims(&[self.h, self.w]),
47            )),
48            (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
49                CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
50                Shape::from_dims(&[self.h, self.w]),
51            )),
52            (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
53                CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
54                Shape::from_dims(&[self.h, self.w]),
55            )),
56            (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
57        }
58    }
59    #[cfg(feature = "metal")]
60    fn metal_fwd(
61        &self,
62        w: &candle_core::MetalStorage,
63        l_w: &Layout,
64        s: &candle_core::MetalStorage,
65        l_s: &Layout,
66        z: &candle_core::MetalStorage,
67        l_z: &Layout,
68    ) -> Result<(candle_core::MetalStorage, Shape)> {
69        if w.dtype() != DType::U8 {
70            candle_core::bail!("Weight must be u8, HQQ dequant 8-bit");
71        };
72        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
73            candle_core::bail!("All inputs must be contiguous");
74        }
75
76        let command_buffer = w.device().command_buffer()?;
77        command_buffer.set_label("dequant-8bit");
78
79        let device = w.device();
80
81        let out_shape = Shape::from_dims(&[self.h, self.w]);
82
83        let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-8bit")?;
84
85        crate::metal_kernels::call_dequant_8bit(
86            device.device(),
87            &command_buffer,
88            &crate::metal_kernels::Kernels::new(),
89            s.dtype(),
90            w.buffer(),
91            s.buffer(),
92            z.buffer(),
93            self.h as u32,
94            self.w as u32,
95            &output,
96        )
97        .map_err(candle_core::Error::wrap)?;
98
99        let newstorage = candle_core::MetalStorage::new(
100            output,
101            device.clone(),
102            out_shape.elem_count(),
103            s.dtype(),
104        );
105        Ok((newstorage, out_shape))
106    }
107}
108
/*
 4 bit
*/
/// HQQ dequantization op for 4-bit weights.
///
/// The quantized weight is a `[h, w]` u8 tensor with two 4-bit codes per byte; the op
/// produces a `[2 * h, w]` tensor in the dtype of the scales/zeros.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant4Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; scales and zeros are indexed per column.
    pub(crate) w: usize,
}
116
117impl Dequant4Bit {
118    fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
119        let output_size = w.len() * 2;
120        let mut out = vec![T::default(); output_size];
121        for (i, w) in w.iter().enumerate() {
122            let j = i % self.w;
123            let nrows = self.h * self.w;
124            out[i] = (T::from_f64(((*w & 0xF0) >> 4) as f64) - z[j]) * s[j];
125            out[i + nrows] = (T::from_f64((*w & 0x0F) as f64) - z[j]) * s[j];
126        }
127        out
128    }
129}
130
131impl CustomOp3 for Dequant4Bit {
132    fn name(&self) -> &'static str {
133        "dequant-hqq-4bit"
134    }
135    fn cpu_fwd(
136        &self,
137        w: &CpuStorage,
138        l_w: &Layout,
139        s: &CpuStorage,
140        l_s: &Layout,
141        z: &CpuStorage,
142        l_z: &Layout,
143    ) -> Result<(CpuStorage, Shape)> {
144        const PACK_FACTOR: usize = 2;
145
146        let CpuStorage::U8(w_slice) = w else {
147            candle_core::bail!("Weight must be u8, HQQ dequant 4-bit");
148        };
149        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
150            candle_core::bail!("All inputs must be contiguous");
151        }
152        match (s, z) {
153            (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
154                CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
155                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
156            )),
157            (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
158                CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
159                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
160            )),
161            (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
162                CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
163                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
164            )),
165            (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
166        }
167    }
168    #[cfg(feature = "metal")]
169    fn metal_fwd(
170        &self,
171        w: &candle_core::MetalStorage,
172        l_w: &Layout,
173        s: &candle_core::MetalStorage,
174        l_s: &Layout,
175        z: &candle_core::MetalStorage,
176        l_z: &Layout,
177    ) -> Result<(candle_core::MetalStorage, Shape)> {
178        const PACK_FACTOR: usize = 2;
179
180        if w.dtype() != DType::U8 {
181            candle_core::bail!("Weight must be u8, HQQ dequant 4-bit");
182        };
183        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
184            candle_core::bail!("All inputs must be contiguous");
185        }
186
187        let command_buffer = w.device().command_buffer()?;
188        command_buffer.set_label("dequant-4bit");
189
190        let device = w.device();
191
192        let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
193
194        let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-4bit")?;
195
196        crate::metal_kernels::call_dequant_4bit(
197            device.device(),
198            &command_buffer,
199            &crate::metal_kernels::Kernels::new(),
200            s.dtype(),
201            w.buffer(),
202            s.buffer(),
203            z.buffer(),
204            self.h as u32,
205            self.w as u32,
206            &output,
207        )
208        .map_err(candle_core::Error::wrap)?;
209
210        let newstorage = candle_core::MetalStorage::new(
211            output,
212            device.clone(),
213            out_shape.elem_count(),
214            s.dtype(),
215        );
216        Ok((newstorage, out_shape))
217    }
218}
219
/*
 2 bit
*/
/// HQQ dequantization op for 2-bit weights.
///
/// The quantized weight is a `[h, w]` u8 tensor with four 2-bit codes per byte; the op
/// produces a `[4 * h, w]` tensor in the dtype of the scales/zeros.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant2Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; scales and zeros are indexed per column.
    pub(crate) w: usize,
}
227
228impl Dequant2Bit {
229    fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
230        let output_size = w.len() * 4;
231        let mut out = vec![T::default(); output_size];
232        for (i, w) in w.iter().enumerate() {
233            let j = i % self.w;
234            let nrows = self.h * self.w;
235            out[i] = (T::from_f64(((*w & 0xC0) >> 6) as f64) - z[j]) * s[j];
236            out[i + nrows] = (T::from_f64(((*w & 0x30) >> 4) as f64) - z[j]) * s[j];
237            out[i + nrows * 2] = (T::from_f64(((*w & 0x0C) >> 2) as f64) - z[j]) * s[j];
238            out[i + nrows * 3] = (T::from_f64((*w & 0x03) as f64) - z[j]) * s[j];
239        }
240        out
241    }
242}
243
244impl CustomOp3 for Dequant2Bit {
245    fn name(&self) -> &'static str {
246        "dequant-hqq-2bit"
247    }
248    fn cpu_fwd(
249        &self,
250        w: &CpuStorage,
251        l_w: &Layout,
252        s: &CpuStorage,
253        l_s: &Layout,
254        z: &CpuStorage,
255        l_z: &Layout,
256    ) -> Result<(CpuStorage, Shape)> {
257        const PACK_FACTOR: usize = 4;
258
259        let CpuStorage::U8(w_slice) = w else {
260            candle_core::bail!("Weight must be u8, HQQ dequant 2-bit");
261        };
262        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
263            candle_core::bail!("All inputs must be contiguous");
264        }
265        match (s, z) {
266            (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
267                CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
268                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
269            )),
270            (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
271                CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
272                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
273            )),
274            (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
275                CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
276                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
277            )),
278            (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
279        }
280    }
281    #[cfg(feature = "metal")]
282    fn metal_fwd(
283        &self,
284        w: &candle_core::MetalStorage,
285        l_w: &Layout,
286        s: &candle_core::MetalStorage,
287        l_s: &Layout,
288        z: &candle_core::MetalStorage,
289        l_z: &Layout,
290    ) -> Result<(candle_core::MetalStorage, Shape)> {
291        const PACK_FACTOR: usize = 4;
292
293        if w.dtype() != DType::U8 {
294            candle_core::bail!("Weight must be u8, HQQ dequant 2-bit");
295        };
296        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
297            candle_core::bail!("All inputs must be contiguous");
298        }
299
300        let command_buffer = w.device().command_buffer()?;
301        command_buffer.set_label("dequant-2bit");
302
303        let device = w.device();
304
305        let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
306
307        let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-2bit")?;
308
309        crate::metal_kernels::call_dequant_2bit(
310            device.device(),
311            &command_buffer,
312            &crate::metal_kernels::Kernels::new(),
313            s.dtype(),
314            w.buffer(),
315            s.buffer(),
316            z.buffer(),
317            self.h as u32,
318            self.w as u32,
319            &output,
320        )
321        .map_err(candle_core::Error::wrap)?;
322
323        let newstorage = candle_core::MetalStorage::new(
324            output,
325            device.clone(),
326            out_shape.elem_count(),
327            s.dtype(),
328        );
329        Ok((newstorage, out_shape))
330    }
331}
332
/*
 1 bit
*/
/// HQQ dequantization op for 1-bit weights.
///
/// The quantized weight is a `[h, w]` u8 tensor with eight 1-bit codes per byte; the op
/// produces an `[8 * h, w]` tensor in the dtype of the scales/zeros.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant1Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; scales and zeros are indexed per column.
    pub(crate) w: usize,
}
340
341impl Dequant1Bit {
342    fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
343        let output_size = w.len() * 8;
344        let mut out = vec![T::default(); output_size];
345        for (i, w) in w.iter().enumerate() {
346            let j = i % self.w;
347            let nrows = self.h * self.w;
348            out[i] = (T::from_f64(((*w & 0x80) >> 7) as f64) - z[j]) * s[j];
349            out[i + nrows] = (T::from_f64(((*w & 0x40) >> 6) as f64) - z[j]) * s[j];
350            out[i + nrows * 2] = (T::from_f64(((*w & 0x20) >> 5) as f64) - z[j]) * s[j];
351            out[i + nrows * 3] = (T::from_f64(((*w & 0x10) >> 4) as f64) - z[j]) * s[j];
352            out[i + nrows * 4] = (T::from_f64(((*w & 0x08) >> 3) as f64) - z[j]) * s[j];
353            out[i + nrows * 5] = (T::from_f64(((*w & 0x04) >> 2) as f64) - z[j]) * s[j];
354            out[i + nrows * 6] = (T::from_f64(((*w & 0x02) >> 1) as f64) - z[j]) * s[j];
355            out[i + nrows * 7] = (T::from_f64((*w & 0x01) as f64) - z[j]) * s[j];
356        }
357        out
358    }
359}
360
361impl CustomOp3 for Dequant1Bit {
362    fn name(&self) -> &'static str {
363        "dequant-hqq-1bit"
364    }
365    fn cpu_fwd(
366        &self,
367        w: &CpuStorage,
368        l_w: &Layout,
369        s: &CpuStorage,
370        l_s: &Layout,
371        z: &CpuStorage,
372        l_z: &Layout,
373    ) -> Result<(CpuStorage, Shape)> {
374        const PACK_FACTOR: usize = 8;
375
376        let CpuStorage::U8(w_slice) = w else {
377            candle_core::bail!("Weight must be u8, HQQ dequant 1-bit");
378        };
379        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
380            candle_core::bail!("All inputs must be contiguous");
381        }
382        match (s, z) {
383            (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
384                CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
385                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
386            )),
387            (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
388                CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
389                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
390            )),
391            (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
392                CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
393                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
394            )),
395            (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
396        }
397    }
398    #[cfg(feature = "metal")]
399    fn metal_fwd(
400        &self,
401        w: &candle_core::MetalStorage,
402        l_w: &Layout,
403        s: &candle_core::MetalStorage,
404        l_s: &Layout,
405        z: &candle_core::MetalStorage,
406        l_z: &Layout,
407    ) -> Result<(candle_core::MetalStorage, Shape)> {
408        const PACK_FACTOR: usize = 8;
409
410        if w.dtype() != DType::U8 {
411            candle_core::bail!("Weight must be u8, HQQ dequant 1-bit");
412        };
413        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
414            candle_core::bail!("All inputs must be contiguous");
415        }
416
417        let command_buffer = w.device().command_buffer()?;
418        command_buffer.set_label("dequant-1bit");
419
420        let device = w.device();
421
422        let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
423
424        let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-1bit")?;
425
426        crate::metal_kernels::call_dequant_1bit(
427            device.device(),
428            &command_buffer,
429            &crate::metal_kernels::Kernels::new(),
430            s.dtype(),
431            w.buffer(),
432            s.buffer(),
433            z.buffer(),
434            self.h as u32,
435            self.w as u32,
436            &output,
437        )
438        .map_err(candle_core::Error::wrap)?;
439
440        let newstorage = candle_core::MetalStorage::new(
441            output,
442            device.clone(),
443            out_shape.elem_count(),
444            s.dtype(),
445        );
446        Ok((newstorage, out_shape))
447    }
448}
449
/*
 3 bit
*/
/// HQQ dequantization op for 3-bit weights.
///
/// The quantized weight is a `[h, w]` i32 tensor. Each element packs ten 3-bit codes in
/// bits 0..=29, and bits 30 and 31 are unused. The op produces a `[10 * h, w]` tensor in
/// the dtype of the scales/zeros.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant3Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; scales and zeros are indexed per column.
    pub(crate) w: usize,
}
457
458impl Dequant3Bit {
459    fn dequantize<T: WithDType + Default>(&self, w: &[i32], s: &[T], z: &[T]) -> Vec<T> {
460        let output_size = w.len() * 10;
461        let mut out = vec![T::default(); output_size];
462        for (i, w) in w.iter().enumerate() {
463            let j = i % self.w;
464            let nrows = self.h * self.w;
465            out[i] = (T::from_f64(((*w & 0x38000000) >> 27) as f64) - z[j]) * s[j];
466            out[i + nrows] = (T::from_f64(((*w & 0x07000000) >> 24) as f64) - z[j]) * s[j];
467            out[i + nrows * 2] = (T::from_f64(((*w & 0x00E00000) >> 21) as f64) - z[j]) * s[j];
468            out[i + nrows * 3] = (T::from_f64(((*w & 0x001C0000) >> 18) as f64) - z[j]) * s[j];
469            out[i + nrows * 4] = (T::from_f64(((*w & 0x00038000) >> 15) as f64) - z[j]) * s[j];
470            out[i + nrows * 5] = (T::from_f64(((*w & 0x00007000) >> 12) as f64) - z[j]) * s[j];
471            out[i + nrows * 6] = (T::from_f64(((*w & 0x00000E00) >> 9) as f64) - z[j]) * s[j];
472            out[i + nrows * 7] = (T::from_f64(((*w & 0x000001C0) >> 6) as f64) - z[j]) * s[j];
473            out[i + nrows * 8] = (T::from_f64(((*w & 0x00000038) >> 3) as f64) - z[j]) * s[j];
474            out[i + nrows * 9] = (T::from_f64((*w & 0x00000007) as f64) - z[j]) * s[j];
475        }
476        out
477    }
478}
479
480impl CustomOp3 for Dequant3Bit {
481    fn name(&self) -> &'static str {
482        "dequant-hqq-3bit"
483    }
484    fn cpu_fwd(
485        &self,
486        w: &CpuStorage,
487        l_w: &Layout,
488        s: &CpuStorage,
489        l_s: &Layout,
490        z: &CpuStorage,
491        l_z: &Layout,
492    ) -> Result<(CpuStorage, Shape)> {
493        const PACK_FACTOR: usize = 10;
494
495        let CpuStorage::I32(w_slice) = w else {
496            candle_core::bail!("Weight must be i32, HQQ dequant 3-bit");
497        };
498        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
499            candle_core::bail!("All inputs must be contiguous");
500        }
501        match (s, z) {
502            (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
503                CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
504                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
505            )),
506            (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
507                CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
508                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
509            )),
510            (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
511                CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
512                Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
513            )),
514            (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
515        }
516    }
517    #[cfg(feature = "metal")]
518    fn metal_fwd(
519        &self,
520        w: &candle_core::MetalStorage,
521        l_w: &Layout,
522        s: &candle_core::MetalStorage,
523        l_s: &Layout,
524        z: &candle_core::MetalStorage,
525        l_z: &Layout,
526    ) -> Result<(candle_core::MetalStorage, Shape)> {
527        const PACK_FACTOR: usize = 10;
528
529        if w.dtype() != DType::I32 {
530            candle_core::bail!("Weight must be i32, HQQ dequant 3-bit");
531        };
532        if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
533            candle_core::bail!("All inputs must be contiguous");
534        }
535
536        let command_buffer = w.device().command_buffer()?;
537        command_buffer.set_label("dequant-3bit");
538
539        let device = w.device();
540
541        let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
542
543        let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-3bit")?;
544
545        crate::metal_kernels::call_dequant_3bit(
546            device.device(),
547            &command_buffer,
548            &crate::metal_kernels::Kernels::new(),
549            s.dtype(),
550            w.buffer(),
551            s.buffer(),
552            z.buffer(),
553            self.h as u32,
554            self.w as u32,
555            &output,
556        )
557        .map_err(candle_core::Error::wrap)?;
558
559        let newstorage = candle_core::MetalStorage::new(
560            output,
561            device.clone(),
562            out_shape.elem_count(),
563            s.dtype(),
564        );
565        Ok((newstorage, out_shape))
566    }
567}