1#[cfg(feature = "metal")]
2use candle_core::{backend::BackendStorage, DType};
3use candle_core::{CpuStorage, CustomOp3, Layout, Result, Shape, WithDType};
4
/// Custom op that dequantizes an HQQ 8-bit weight.
///
/// The quantized input holds one `u8` per output element, so the result has shape `(h, w)`.
/// Element `i` (row-major) is dequantized with scale / zero point number `i % w`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant8Bit {
    /// Number of rows of the quantized weight.
    pub(crate) h: usize,
    /// Number of columns of the quantized weight; also the number of scales / zero points used.
    pub(crate) w: usize,
}
12
13impl Dequant8Bit {
14 fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
15 let mut out = vec![T::default(); w.len()];
16 for (i, w) in w.iter().enumerate() {
17 let j = i % self.w;
18 out[i] = (T::from_f64(*w as f64) - z[j]) * s[j];
19 }
20 out
21 }
22}
23
24impl CustomOp3 for Dequant8Bit {
25 fn name(&self) -> &'static str {
26 "dequant-hqq-8bit"
27 }
28 fn cpu_fwd(
29 &self,
30 w: &CpuStorage,
31 l_w: &Layout,
32 s: &CpuStorage,
33 l_s: &Layout,
34 z: &CpuStorage,
35 l_z: &Layout,
36 ) -> Result<(CpuStorage, Shape)> {
37 let CpuStorage::U8(w_slice) = w else {
38 candle_core::bail!("Weight must be u8, HQQ dequant 8-bit");
39 };
40 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
41 candle_core::bail!("All inputs must be contiguous");
42 }
43 match (s, z) {
44 (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
45 CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
46 Shape::from_dims(&[self.h, self.w]),
47 )),
48 (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
49 CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
50 Shape::from_dims(&[self.h, self.w]),
51 )),
52 (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
53 CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
54 Shape::from_dims(&[self.h, self.w]),
55 )),
56 (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
57 }
58 }
59 #[cfg(feature = "metal")]
60 fn metal_fwd(
61 &self,
62 w: &candle_core::MetalStorage,
63 l_w: &Layout,
64 s: &candle_core::MetalStorage,
65 l_s: &Layout,
66 z: &candle_core::MetalStorage,
67 l_z: &Layout,
68 ) -> Result<(candle_core::MetalStorage, Shape)> {
69 if w.dtype() != DType::U8 {
70 candle_core::bail!("Weight must be u8, HQQ dequant 8-bit");
71 };
72 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
73 candle_core::bail!("All inputs must be contiguous");
74 }
75
76 let command_buffer = w.device().command_buffer()?;
77 command_buffer.set_label("dequant-8bit");
78
79 let device = w.device();
80
81 let out_shape = Shape::from_dims(&[self.h, self.w]);
82
83 let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-8bit")?;
84
85 crate::metal_kernels::call_dequant_8bit(
86 device.device(),
87 &command_buffer,
88 &crate::metal_kernels::Kernels::new(),
89 s.dtype(),
90 w.buffer(),
91 s.buffer(),
92 z.buffer(),
93 self.h as u32,
94 self.w as u32,
95 &output,
96 )
97 .map_err(candle_core::Error::wrap)?;
98
99 let newstorage = candle_core::MetalStorage::new(
100 output,
101 device.clone(),
102 out_shape.elem_count(),
103 s.dtype(),
104 );
105 Ok((newstorage, out_shape))
106 }
107}
108
/// Custom op that dequantizes an HQQ 4-bit weight.
///
/// Each `u8` of the `(h, w)` packed input holds two 4-bit values. The result has shape
/// `(2 * h, w)`:
/// - rows `0..h` come from the high nibbles;
/// - rows `h..2h` come from the low nibbles.
///
/// Packed element `i` uses scale / zero point number `i % w`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant4Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; also the number of scales / zero points used.
    pub(crate) w: usize,
}
116
117impl Dequant4Bit {
118 fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
119 let output_size = w.len() * 2;
120 let mut out = vec![T::default(); output_size];
121 for (i, w) in w.iter().enumerate() {
122 let j = i % self.w;
123 let nrows = self.h * self.w;
124 out[i] = (T::from_f64(((*w & 0xF0) >> 4) as f64) - z[j]) * s[j];
125 out[i + nrows] = (T::from_f64((*w & 0x0F) as f64) - z[j]) * s[j];
126 }
127 out
128 }
129}
130
131impl CustomOp3 for Dequant4Bit {
132 fn name(&self) -> &'static str {
133 "dequant-hqq-4bit"
134 }
135 fn cpu_fwd(
136 &self,
137 w: &CpuStorage,
138 l_w: &Layout,
139 s: &CpuStorage,
140 l_s: &Layout,
141 z: &CpuStorage,
142 l_z: &Layout,
143 ) -> Result<(CpuStorage, Shape)> {
144 const PACK_FACTOR: usize = 2;
145
146 let CpuStorage::U8(w_slice) = w else {
147 candle_core::bail!("Weight must be u8, HQQ dequant 4-bit");
148 };
149 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
150 candle_core::bail!("All inputs must be contiguous");
151 }
152 match (s, z) {
153 (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
154 CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
155 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
156 )),
157 (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
158 CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
159 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
160 )),
161 (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
162 CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
163 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
164 )),
165 (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
166 }
167 }
168 #[cfg(feature = "metal")]
169 fn metal_fwd(
170 &self,
171 w: &candle_core::MetalStorage,
172 l_w: &Layout,
173 s: &candle_core::MetalStorage,
174 l_s: &Layout,
175 z: &candle_core::MetalStorage,
176 l_z: &Layout,
177 ) -> Result<(candle_core::MetalStorage, Shape)> {
178 const PACK_FACTOR: usize = 2;
179
180 if w.dtype() != DType::U8 {
181 candle_core::bail!("Weight must be u8, HQQ dequant 4-bit");
182 };
183 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
184 candle_core::bail!("All inputs must be contiguous");
185 }
186
187 let command_buffer = w.device().command_buffer()?;
188 command_buffer.set_label("dequant-4bit");
189
190 let device = w.device();
191
192 let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
193
194 let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-4bit")?;
195
196 crate::metal_kernels::call_dequant_4bit(
197 device.device(),
198 &command_buffer,
199 &crate::metal_kernels::Kernels::new(),
200 s.dtype(),
201 w.buffer(),
202 s.buffer(),
203 z.buffer(),
204 self.h as u32,
205 self.w as u32,
206 &output,
207 )
208 .map_err(candle_core::Error::wrap)?;
209
210 let newstorage = candle_core::MetalStorage::new(
211 output,
212 device.clone(),
213 out_shape.elem_count(),
214 s.dtype(),
215 );
216 Ok((newstorage, out_shape))
217 }
218}
219
/// Custom op that dequantizes an HQQ 2-bit weight.
///
/// Each `u8` of the `(h, w)` packed input holds four 2-bit fields, most significant first. The
/// result has shape `(4 * h, w)`: row block `k` (rows `k * h..(k + 1) * h`) holds the `k`-th
/// most significant field. Packed element `i` uses scale / zero point number `i % w`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant2Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; also the number of scales / zero points used.
    pub(crate) w: usize,
}
227
228impl Dequant2Bit {
229 fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
230 let output_size = w.len() * 4;
231 let mut out = vec![T::default(); output_size];
232 for (i, w) in w.iter().enumerate() {
233 let j = i % self.w;
234 let nrows = self.h * self.w;
235 out[i] = (T::from_f64(((*w & 0xC0) >> 6) as f64) - z[j]) * s[j];
236 out[i + nrows] = (T::from_f64(((*w & 0x30) >> 4) as f64) - z[j]) * s[j];
237 out[i + nrows * 2] = (T::from_f64(((*w & 0x0C) >> 2) as f64) - z[j]) * s[j];
238 out[i + nrows * 3] = (T::from_f64((*w & 0x03) as f64) - z[j]) * s[j];
239 }
240 out
241 }
242}
243
244impl CustomOp3 for Dequant2Bit {
245 fn name(&self) -> &'static str {
246 "dequant-hqq-2bit"
247 }
248 fn cpu_fwd(
249 &self,
250 w: &CpuStorage,
251 l_w: &Layout,
252 s: &CpuStorage,
253 l_s: &Layout,
254 z: &CpuStorage,
255 l_z: &Layout,
256 ) -> Result<(CpuStorage, Shape)> {
257 const PACK_FACTOR: usize = 4;
258
259 let CpuStorage::U8(w_slice) = w else {
260 candle_core::bail!("Weight must be u8, HQQ dequant 2-bit");
261 };
262 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
263 candle_core::bail!("All inputs must be contiguous");
264 }
265 match (s, z) {
266 (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
267 CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
268 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
269 )),
270 (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
271 CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
272 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
273 )),
274 (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
275 CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
276 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
277 )),
278 (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
279 }
280 }
281 #[cfg(feature = "metal")]
282 fn metal_fwd(
283 &self,
284 w: &candle_core::MetalStorage,
285 l_w: &Layout,
286 s: &candle_core::MetalStorage,
287 l_s: &Layout,
288 z: &candle_core::MetalStorage,
289 l_z: &Layout,
290 ) -> Result<(candle_core::MetalStorage, Shape)> {
291 const PACK_FACTOR: usize = 4;
292
293 if w.dtype() != DType::U8 {
294 candle_core::bail!("Weight must be u8, HQQ dequant 2-bit");
295 };
296 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
297 candle_core::bail!("All inputs must be contiguous");
298 }
299
300 let command_buffer = w.device().command_buffer()?;
301 command_buffer.set_label("dequant-2bit");
302
303 let device = w.device();
304
305 let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
306
307 let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-2bit")?;
308
309 crate::metal_kernels::call_dequant_2bit(
310 device.device(),
311 &command_buffer,
312 &crate::metal_kernels::Kernels::new(),
313 s.dtype(),
314 w.buffer(),
315 s.buffer(),
316 z.buffer(),
317 self.h as u32,
318 self.w as u32,
319 &output,
320 )
321 .map_err(candle_core::Error::wrap)?;
322
323 let newstorage = candle_core::MetalStorage::new(
324 output,
325 device.clone(),
326 out_shape.elem_count(),
327 s.dtype(),
328 );
329 Ok((newstorage, out_shape))
330 }
331}
332
/// Custom op that dequantizes an HQQ 1-bit weight.
///
/// Each `u8` of the `(h, w)` packed input holds eight bits, most significant first. The result
/// has shape `(8 * h, w)`: row block `k` (rows `k * h..(k + 1) * h`) holds bit `7 - k`. Packed
/// element `i` uses scale / zero point number `i % w`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant1Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; also the number of scales / zero points used.
    pub(crate) w: usize,
}
340
341impl Dequant1Bit {
342 fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
343 let output_size = w.len() * 8;
344 let mut out = vec![T::default(); output_size];
345 for (i, w) in w.iter().enumerate() {
346 let j = i % self.w;
347 let nrows = self.h * self.w;
348 out[i] = (T::from_f64(((*w & 0x80) >> 7) as f64) - z[j]) * s[j];
349 out[i + nrows] = (T::from_f64(((*w & 0x40) >> 6) as f64) - z[j]) * s[j];
350 out[i + nrows * 2] = (T::from_f64(((*w & 0x20) >> 5) as f64) - z[j]) * s[j];
351 out[i + nrows * 3] = (T::from_f64(((*w & 0x10) >> 4) as f64) - z[j]) * s[j];
352 out[i + nrows * 4] = (T::from_f64(((*w & 0x08) >> 3) as f64) - z[j]) * s[j];
353 out[i + nrows * 5] = (T::from_f64(((*w & 0x04) >> 2) as f64) - z[j]) * s[j];
354 out[i + nrows * 6] = (T::from_f64(((*w & 0x02) >> 1) as f64) - z[j]) * s[j];
355 out[i + nrows * 7] = (T::from_f64((*w & 0x01) as f64) - z[j]) * s[j];
356 }
357 out
358 }
359}
360
361impl CustomOp3 for Dequant1Bit {
362 fn name(&self) -> &'static str {
363 "dequant-hqq-1bit"
364 }
365 fn cpu_fwd(
366 &self,
367 w: &CpuStorage,
368 l_w: &Layout,
369 s: &CpuStorage,
370 l_s: &Layout,
371 z: &CpuStorage,
372 l_z: &Layout,
373 ) -> Result<(CpuStorage, Shape)> {
374 const PACK_FACTOR: usize = 8;
375
376 let CpuStorage::U8(w_slice) = w else {
377 candle_core::bail!("Weight must be u8, HQQ dequant 1-bit");
378 };
379 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
380 candle_core::bail!("All inputs must be contiguous");
381 }
382 match (s, z) {
383 (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
384 CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
385 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
386 )),
387 (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
388 CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
389 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
390 )),
391 (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
392 CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
393 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
394 )),
395 (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
396 }
397 }
398 #[cfg(feature = "metal")]
399 fn metal_fwd(
400 &self,
401 w: &candle_core::MetalStorage,
402 l_w: &Layout,
403 s: &candle_core::MetalStorage,
404 l_s: &Layout,
405 z: &candle_core::MetalStorage,
406 l_z: &Layout,
407 ) -> Result<(candle_core::MetalStorage, Shape)> {
408 const PACK_FACTOR: usize = 8;
409
410 if w.dtype() != DType::U8 {
411 candle_core::bail!("Weight must be u8, HQQ dequant 1-bit");
412 };
413 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
414 candle_core::bail!("All inputs must be contiguous");
415 }
416
417 let command_buffer = w.device().command_buffer()?;
418 command_buffer.set_label("dequant-1bit");
419
420 let device = w.device();
421
422 let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
423
424 let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-1bit")?;
425
426 crate::metal_kernels::call_dequant_1bit(
427 device.device(),
428 &command_buffer,
429 &crate::metal_kernels::Kernels::new(),
430 s.dtype(),
431 w.buffer(),
432 s.buffer(),
433 z.buffer(),
434 self.h as u32,
435 self.w as u32,
436 &output,
437 )
438 .map_err(candle_core::Error::wrap)?;
439
440 let newstorage = candle_core::MetalStorage::new(
441 output,
442 device.clone(),
443 out_shape.elem_count(),
444 s.dtype(),
445 );
446 Ok((newstorage, out_shape))
447 }
448}
449
/// Custom op that dequantizes an HQQ 3-bit weight.
///
/// Each `i32` of the `(h, w)` packed input holds ten 3-bit fields in bits 29..0, most
/// significant first; bits 30 and 31 are ignored. The result has shape `(10 * h, w)`: row
/// block `k` (rows `k * h..(k + 1) * h`) holds the `k`-th field. Packed element `i` uses
/// scale / zero point number `i % w`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Dequant3Bit {
    /// Number of rows of the packed weight.
    pub(crate) h: usize,
    /// Number of columns of the packed weight; also the number of scales / zero points used.
    pub(crate) w: usize,
}
457
458impl Dequant3Bit {
459 fn dequantize<T: WithDType + Default>(&self, w: &[i32], s: &[T], z: &[T]) -> Vec<T> {
460 let output_size = w.len() * 10;
461 let mut out = vec![T::default(); output_size];
462 for (i, w) in w.iter().enumerate() {
463 let j = i % self.w;
464 let nrows = self.h * self.w;
465 out[i] = (T::from_f64(((*w & 0x38000000) >> 27) as f64) - z[j]) * s[j];
466 out[i + nrows] = (T::from_f64(((*w & 0x07000000) >> 24) as f64) - z[j]) * s[j];
467 out[i + nrows * 2] = (T::from_f64(((*w & 0x00E00000) >> 21) as f64) - z[j]) * s[j];
468 out[i + nrows * 3] = (T::from_f64(((*w & 0x001C0000) >> 18) as f64) - z[j]) * s[j];
469 out[i + nrows * 4] = (T::from_f64(((*w & 0x00038000) >> 15) as f64) - z[j]) * s[j];
470 out[i + nrows * 5] = (T::from_f64(((*w & 0x00007000) >> 12) as f64) - z[j]) * s[j];
471 out[i + nrows * 6] = (T::from_f64(((*w & 0x00000E00) >> 9) as f64) - z[j]) * s[j];
472 out[i + nrows * 7] = (T::from_f64(((*w & 0x000001C0) >> 6) as f64) - z[j]) * s[j];
473 out[i + nrows * 8] = (T::from_f64(((*w & 0x00000038) >> 3) as f64) - z[j]) * s[j];
474 out[i + nrows * 9] = (T::from_f64((*w & 0x00000007) as f64) - z[j]) * s[j];
475 }
476 out
477 }
478}
479
480impl CustomOp3 for Dequant3Bit {
481 fn name(&self) -> &'static str {
482 "dequant-hqq-3bit"
483 }
484 fn cpu_fwd(
485 &self,
486 w: &CpuStorage,
487 l_w: &Layout,
488 s: &CpuStorage,
489 l_s: &Layout,
490 z: &CpuStorage,
491 l_z: &Layout,
492 ) -> Result<(CpuStorage, Shape)> {
493 const PACK_FACTOR: usize = 10;
494
495 let CpuStorage::I32(w_slice) = w else {
496 candle_core::bail!("Weight must be i32, HQQ dequant 3-bit");
497 };
498 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
499 candle_core::bail!("All inputs must be contiguous");
500 }
501 match (s, z) {
502 (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok((
503 CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)),
504 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
505 )),
506 (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok((
507 CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)),
508 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
509 )),
510 (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok((
511 CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)),
512 Shape::from_dims(&[PACK_FACTOR * self.h, self.w]),
513 )),
514 (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"),
515 }
516 }
517 #[cfg(feature = "metal")]
518 fn metal_fwd(
519 &self,
520 w: &candle_core::MetalStorage,
521 l_w: &Layout,
522 s: &candle_core::MetalStorage,
523 l_s: &Layout,
524 z: &candle_core::MetalStorage,
525 l_z: &Layout,
526 ) -> Result<(candle_core::MetalStorage, Shape)> {
527 const PACK_FACTOR: usize = 10;
528
529 if w.dtype() != DType::I32 {
530 candle_core::bail!("Weight must be i32, HQQ dequant 3-bit");
531 };
532 if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) {
533 candle_core::bail!("All inputs must be contiguous");
534 }
535
536 let command_buffer = w.device().command_buffer()?;
537 command_buffer.set_label("dequant-3bit");
538
539 let device = w.device();
540
541 let out_shape = Shape::from_dims(&[PACK_FACTOR * self.h, self.w]);
542
543 let output = device.new_buffer(out_shape.elem_count(), s.dtype(), "dequant-3bit")?;
544
545 crate::metal_kernels::call_dequant_3bit(
546 device.device(),
547 &command_buffer,
548 &crate::metal_kernels::Kernels::new(),
549 s.dtype(),
550 w.buffer(),
551 s.buffer(),
552 z.buffer(),
553 self.h as u32,
554 self.w as u32,
555 &output,
556 )
557 .map_err(candle_core::Error::wrap)?;
558
559 let newstorage = candle_core::MetalStorage::new(
560 output,
561 device.clone(),
562 out_shape.elem_count(),
563 s.dtype(),
564 );
565 Ok((newstorage, out_shape))
566 }
567}