1#![allow(unused)]
2
3use candle_core::{
4 backend::BackendStorage, from_storage_no_op, DType, MetalStorage, Result, Shape, Storage,
5 Tensor, D,
6};
7
8use super::{AfqBits, AfqGroupSize};
9
/// Affine-quantize a weight matrix `w` (AFQ format) on Metal.
///
/// Returns `(w_q, scales, biases)`:
/// - `w_q`: packed `u32` data; last dim becomes `last_dim * bits / 32`.
/// - `scales`, `biases`: per-group affine parameters in `w.dtype()`; last dim
///   becomes `last_dim / group_size`.
///
/// # Errors
/// - `w` has rank < 2.
/// - `w`'s last dim is not divisible by `group_size`.
/// - The `metal` feature is disabled, or storage is not Metal.
pub(crate) fn afq_quantize_op(
    w: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<(Tensor, Tensor, Tensor)> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w.rank() < 2 {
        candle_core::bail!("AFQ quantize expects weight matrix of at least rank 2");
    }
    if w.dim(D::Minus1)? % group_size != 0 {
        candle_core::bail!(
            "Last dim of weight matrix ({:?}) must be divisible by group size {group_size}.",
            w.dims()
        );
    }

    #[cfg(feature = "metal")]
    {
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-quantize");

        // Packed output keeps all leading dims; only the last dim shrinks:
        // each u32 word stores 32/bits quantized values.
        let mut wq_shape = w.dims().to_vec();
        *wq_shape.last_mut().unwrap() = w.dim(D::Minus1)? * bits / 32;
        // One scale/bias per `group_size` elements along the last dim.
        let mut s_shape = w.dims().to_vec();
        *s_shape.last_mut().unwrap() = w.dim(D::Minus1)? / group_size;

        let output =
            device.new_buffer(wq_shape.iter().product(), DType::U32, "afq-quantize-output")?;
        let scales =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-scales")?;
        let biases =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-biases")?;

        // Kernel reads from buffer start; a nonzero start offset would make it
        // read the wrong data.
        assert_eq!(w.layout().start_offset(), 0);
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            w.dtype(),
            w_s.buffer(),
            w.dims(),
            w.stride(),
            &output,
            &wq_shape,
            &scales,
            &biases,
            // false: quantize direction (dequantize passes true here —
            // presumably a mode toggle in the shared kernel; confirm in
            // metal_kernels).
            false,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        // Wrap the raw Metal buffers into Tensors without recording an op
        // (no autograd graph entry).
        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                wq_shape.iter().product(),
                DType::U32,
            )),
            wq_shape,
            false,
        );
        let scales = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                scales,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape.clone(),
            false,
        );
        let biases = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                biases,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape,
            false,
        );

        Ok((output, scales, biases))
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_quantize_op` only works on Metal.")
    }
}
109
/// Dequantize AFQ-packed weights back to floating point on Metal.
///
/// `w_q` holds packed `u32` values; `scales` and `biases` are the per-group
/// affine parameters produced by `afq_quantize_op`. The output has
/// `scales.dtype()` and its last dim expands to `w_q.last_dim * 32 / bits`.
///
/// # Errors
/// - Any input has rank < 2.
/// - `scales`/`biases` last dims are inconsistent with `w_q` under the given
///   `group_size`/`bits`.
/// - The `metal` feature is disabled, or storage is not Metal.
pub(crate) fn afq_dequantize_op(
    w_q: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w_q.rank() < 2 || scales.rank() < 2 || biases.rank() < 2 {
        candle_core::bail!("AFQ dequantize expects all matrices of at least rank 2");
    }

    #[cfg(feature = "metal")]
    {
        let wq_s = w_q.storage_and_layout().0;
        let Storage::Metal(wq_s) = &*wq_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = wq_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-dequantize");

        // Each u32 word unpacks into 32/bits values along the last dim.
        let out_size = w_q.dim(D::Minus1)? * 32 / bits;
        let mut w_shape = w_q.dims().to_vec();
        *w_shape.last_mut().unwrap() = out_size;

        // scales/biases must carry exactly one entry per `group_size`
        // dequantized elements.
        if out_size != scales.dim(D::Minus1)? * group_size
            || out_size != biases.dim(D::Minus1)? * group_size
        {
            candle_core::bail!(
                "Scales and biases do not match the matrix given dequantization parameters."
            );
        }

        let output = device.new_buffer(
            w_shape.iter().product(),
            scales.dtype(),
            "afq-dequantize-output",
        )?;

        // Kernel reads each buffer from its start; nonzero offsets would
        // silently read the wrong region.
        assert_eq!(w_q.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            scales.dtype(),
            wq_s.buffer(),
            w_q.dims(),
            w_q.stride(),
            &output,
            &w_shape,
            s_s.buffer(),
            b_s.buffer(),
            // true: dequantize direction (quantize passes false here —
            // presumably a mode toggle in the shared kernel; confirm in
            // metal_kernels).
            true,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        // Wrap the raw buffer into a Tensor without an op-graph entry.
        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                w_shape.iter().product(),
                scales.dtype(),
            )),
            w_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_dequantize_op` only works on Metal.")
    }
}
201
202fn make_dummy_indices(x: &Tensor) -> Result<Tensor> {
203 let x_batches = x
204 .dims()
205 .iter()
206 .take(x.rank() - 2)
207 .copied()
208 .collect::<Vec<_>>();
209
210 (Tensor::ones(x_batches.iter().product::<usize>(), DType::F32, x.device())?
211 .cumsum(0)?
212 .to_dtype(DType::U32)?
213 - 1.)?
214 .reshape(x_batches)
215}
216
217#[allow(clippy::too_many_arguments)]
219pub(crate) fn afq_mm_op(
220 x: &Tensor,
221 w: &Tensor,
222 scales: &Tensor,
223 biases: &Tensor,
224 lhs_indices: Option<&Tensor>,
225 rhs_indices: Option<&Tensor>,
226 group_size: AfqGroupSize,
227 bits: AfqBits,
228 transpose: bool,
229) -> Result<Tensor> {
230 let group_size = group_size as usize;
231 let bits = bits as usize;
232
233 let w_outer_dims = {
234 if w.dtype() != DType::U32 {
235 candle_core::bail!("AFQ weight matrix must be u32");
236 }
237 if scales.dims() != biases.dims() {
238 candle_core::bail!("Scales and biases should have the same shapes");
239 }
240 if w.dim(D::Minus1)? * 32 / bits != scales.dim(D::Minus1)? * group_size {
241 candle_core::bail!("Last dims of w and scales must be compatible.");
242 }
243
244 let x_inner_dims = x.dim(D::Minus1)?;
245
246 let w_inner_dims = if transpose {
248 w.dim(D::Minus1)? * 32 / bits
249 } else {
250 w.dim(D::Minus2)?
251 };
252 let w_outer_dims = if transpose {
253 w.dim(D::Minus2)?
254 } else {
255 w.dim(D::Minus1)? * 32 / bits
256 };
257
258 if w_inner_dims != x_inner_dims {
259 candle_core::bail!(
260 "w inner dims ({:?}) must match x inner dims ({:?}). transpose={transpose}",
261 w.dims(),
262 x.dims()
263 );
264 }
265
266 w_outer_dims
267 };
268
269 #[cfg(feature = "metal")]
270 {
271 let x_s = x.storage_and_layout().0;
272 let Storage::Metal(x_s) = &*x_s else {
273 candle_core::bail!("expected metal")
274 };
275 let w_s = w.storage_and_layout().0;
276 let Storage::Metal(w_s) = &*w_s else {
277 candle_core::bail!("expected metal")
278 };
279 let s_s = scales.storage_and_layout().0;
280 let Storage::Metal(s_s) = &*s_s else {
281 candle_core::bail!("expected metal")
282 };
283 let b_s = biases.storage_and_layout().0;
284 let Storage::Metal(b_s) = &*b_s else {
285 candle_core::bail!("expected metal")
286 };
287
288 let device = w_s.device();
289
290 let command_buffer = device.command_buffer()?;
291 command_buffer.set_label("afq-qmm");
292
293 assert_eq!(x.layout().start_offset(), 0);
294 assert_eq!(w.layout().start_offset(), 0);
295 assert_eq!(scales.layout().start_offset(), 0);
296 assert_eq!(biases.layout().start_offset(), 0);
297
298 let (output, out_shape) = if lhs_indices.is_some() || rhs_indices.is_some() {
299 let mut lhs_indices = match lhs_indices {
300 Some(lhs_indices) => lhs_indices.clone(),
301 None => make_dummy_indices(x)?,
302 };
303 let mut rhs_indices = match rhs_indices {
304 Some(rhs_indices) => rhs_indices.clone(),
305 None => make_dummy_indices(w)?,
306 };
307 if lhs_indices.dtype() != DType::U32 || rhs_indices.dtype() != DType::U32 {
308 candle_core::bail!("lhs and rhs indices must be u32.")
309 }
310 {
312 let mut shape = lhs_indices.shape().clone();
313 shape = rhs_indices
314 .shape()
315 .broadcast_shape_binary_op(rhs_indices.shape(), "afq-qmm")?;
316 lhs_indices = lhs_indices.broadcast_as(shape.clone())?;
317 rhs_indices = rhs_indices.broadcast_as(shape)?;
318 }
319
320 let li_s = lhs_indices.storage_and_layout().0;
321 let Storage::Metal(li_s) = &*li_s else {
322 candle_core::bail!("expected metal")
323 };
324 let ri_s = rhs_indices.storage_and_layout().0;
325 let Storage::Metal(ri_s) = &*ri_s else {
326 candle_core::bail!("expected metal")
327 };
328
329 let mut out_shape = lhs_indices.dims().to_vec();
330 out_shape.push(x.dim(D::Minus2)?);
331 out_shape.push(w_outer_dims);
332
333 let output =
334 device.new_buffer(out_shape.iter().product(), scales.dtype(), "afq-qmm-output")?;
335
336 crate::metal_kernels::call_afq_qmm(
337 device.device(),
338 &command_buffer,
339 &crate::metal_kernels::Kernels::new(),
340 scales.dtype(),
341 x_s.buffer(),
342 x.dims(),
343 x.stride(),
344 w_s.buffer(),
345 w.dims(),
346 w.stride(),
347 s_s.buffer(),
348 scales.stride(),
349 b_s.buffer(),
350 biases.stride(),
351 &output,
352 &out_shape,
353 Some((li_s.buffer(), ri_s.buffer())),
354 Some(lhs_indices.dims()),
355 Some((lhs_indices.stride(), rhs_indices.stride())),
356 transpose,
357 bits,
358 group_size,
359 )
360 .map_err(candle_core::Error::wrap)?;
361
362 (output, out_shape)
363 } else {
364 let mut out_shape = x.dims().to_vec();
365 *out_shape.last_mut().unwrap() = w_outer_dims;
366
367 let output =
368 device.new_buffer(out_shape.iter().product(), scales.dtype(), "afq-qmm-output")?;
369
370 crate::metal_kernels::call_afq_qmm(
371 device.device(),
372 &command_buffer,
373 &crate::metal_kernels::Kernels::new(),
374 scales.dtype(),
375 x_s.buffer(),
376 x.dims(),
377 x.stride(),
378 w_s.buffer(),
379 w.dims(),
380 w.stride(),
381 s_s.buffer(),
382 scales.stride(),
383 b_s.buffer(),
384 biases.stride(),
385 &output,
386 &out_shape,
387 None,
388 None,
389 None,
390 transpose,
391 bits,
392 group_size,
393 )
394 .map_err(candle_core::Error::wrap)?;
395
396 (output, out_shape)
397 };
398
399 let output = from_storage_no_op(
400 Storage::Metal(MetalStorage::new(
401 output,
402 device.clone(),
403 out_shape.iter().product(),
404 scales.dtype(),
405 )),
406 out_shape,
407 false,
408 );
409
410 Ok(output)
411 }
412 #[cfg(not(feature = "metal"))]
413 {
414 candle_core::bail!("`afq_mm_op` only works on Metal.")
415 }
416}
417
#[cfg(feature = "metal")]
#[cfg(test)]
mod metal_tests {
    use candle_core::{DType, Device, Result, Tensor, D};

    use crate::{afq::ops::afq_dequantize_op, AfqBits, AfqGroupSize};

    use super::afq_quantize_op;

    /// Quantize then dequantize a random 32x32 matrix at the given bit width
    /// and return the RMSE between original and round-tripped values.
    fn run_afq_roundtrip(bits: AfqBits) -> Result<f32> {
        let device = Device::new_metal(0)?;
        let group_size = AfqGroupSize::Low;

        let xs = Tensor::randn(0f32, 1f32, (32, 32), &device)?;
        let (w_q, scales, biases) = afq_quantize_op(&xs, group_size, bits)?;
        let ys = afq_dequantize_op(&w_q, &scales, &biases, group_size, bits)?;

        // Per-row RMSE, then averaged across rows.
        let err = (xs - ys)?;
        let per_row = err.sqr()?.mean(D::Minus1)?.sqrt()?;
        per_row
            .mean_all()?
            .to_dtype(DType::F32)?
            .to_scalar::<f32>()
    }

    /// Shared assertion: round-trip error must stay under `limit`.
    fn assert_rmse_below(bits: AfqBits, limit: f32) -> Result<()> {
        let rmse = run_afq_roundtrip(bits)?;
        assert!(rmse < limit, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_eight() -> Result<()> {
        assert_rmse_below(AfqBits::Eight, 0.005)
    }

    #[test]
    fn test_afq_six() -> Result<()> {
        assert_rmse_below(AfqBits::Six, 0.02)
    }

    #[test]
    fn test_afq_four() -> Result<()> {
        assert_rmse_below(AfqBits::Four, 0.078)
    }

    #[test]
    fn test_afq_three() -> Result<()> {
        assert_rmse_below(AfqBits::Three, 0.17)
    }

    #[test]
    fn test_afq_two() -> Result<()> {
        assert_rmse_below(AfqBits::Two, 0.35)
    }
}