#![allow(unused)]

use candle_core::{
    backend::BackendStorage, from_storage_no_op, DType, MetalStorage, Result, Storage, Tensor, D,
};

use super::{AfqBits, AfqGroupSize};

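/// Affine quantization (AFQ) of a weight tensor on Metal.
///
/// Packs each group of `group_size` values along the last dim into `u32` words
/// (`32 / bits` values per word) and returns `(w_q, scales, biases)`, where the
/// per-group scale/bias pair affinely maps the packed integers back to floats.
/// Bails on non-Metal backends.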
pub(crate) fn afq_quantize_op(
    w: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<(Tensor, Tensor, Tensor)> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w.rank() < 2 {
        candle_core::bail!("AFQ quantize expects a weight matrix of at least rank 2");
    }
    if w.dim(D::Minus1)? % group_size != 0 {
        candle_core::bail!(
            "Last dim of weight matrix ({:?}) must be divisible by group size {group_size}.",
            w.dims()
        );
    }

    #[cfg(feature = "metal")]
    {
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-quantize");

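        // Output shape math: each u32 word packs 32 / bits quantized values, so
        // the packed last dim is last * bits / 32; scales and biases carry one
        // entry per quantization group.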
        let mut wq_shape = w.dims().to_vec();
        *wq_shape.last_mut().unwrap() = w.dim(D::Minus1)? * bits / 32;
        let mut s_shape = w.dims().to_vec();
        *s_shape.last_mut().unwrap() = w.dim(D::Minus1)? / group_size;

        let output =
            device.new_buffer(wq_shape.iter().product(), DType::U32, "afq-quantize-output")?;
        let scales =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-scales")?;
        let biases =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-biases")?;

        assert_eq!(w.layout().start_offset(), 0);
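        // The kernel assumes contiguous storage starting at offset 0 (asserted
        // above). The boolean flag appears to select the direction: `false`
        // quantizes here, while `afq_dequantize_op` passes `true` through the
        // same entry point.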
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            w.dtype(),
            w_s.buffer(),
            w.dims(),
            w.stride(),
            &output,
            &wq_shape,
            &scales,
            &biases,
            false,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                wq_shape.iter().product(),
                DType::U32,
            )),
            wq_shape,
            false,
        );
        let scales = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                scales,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape.clone(),
            false,
        );
        let biases = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                biases,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape,
            false,
        );

        Ok((output, scales, biases))
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_quantize_op` only works on Metal.")
    }
}

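/// Inverse of [`afq_quantize_op`]: reconstructs an approximation of the
/// original weights from the packed `u32` words plus the per-group scales and
/// biases. Bails on non-Metal backends.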
pub(crate) fn afq_dequantize_op(
    w_q: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w_q.rank() < 2 || scales.rank() < 2 || biases.rank() < 2 {
        candle_core::bail!("AFQ dequantize expects all matrices of at least rank 2");
    }

    #[cfg(feature = "metal")]
    {
        let wq_s = w_q.storage_and_layout().0;
        let Storage::Metal(wq_s) = &*wq_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = wq_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-dequantize");

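        // Invert the packing: each u32 word holds 32 / bits values, so the
        // dequantized last dim is the packed dim times 32 / bits.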
        let out_size = w_q.dim(D::Minus1)? * 32 / bits;
        let mut w_shape = w_q.dims().to_vec();
        *w_shape.last_mut().unwrap() = out_size;

        if out_size != scales.dim(D::Minus1)? * group_size
            || out_size != biases.dim(D::Minus1)? * group_size
        {
            candle_core::bail!(
                "Scales and biases do not match the weight matrix for the given dequantization parameters."
            );
        }

        let output = device.new_buffer(
            w_shape.iter().product(),
            scales.dtype(),
            "afq-dequantize-output",
        )?;

        assert_eq!(w_q.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);
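        // Same affine kernel as in `afq_quantize_op`; `true` selects the
        // dequantize direction, writing output in the scales' dtype.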
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            scales.dtype(),
            wq_s.buffer(),
            w_q.dims(),
            w_q.stride(),
            &output,
            &w_shape,
            s_s.buffer(),
            b_s.buffer(),
            true,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                w_shape.iter().product(),
                scales.dtype(),
            )),
            w_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_dequantize_op` only works on Metal.")
    }
}

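/// Quantized matmul: computes `x @ w` (or `x @ w^T` when `transpose` is set)
/// directly against the packed representation produced by [`afq_quantize_op`],
/// using the per-group scales and biases. Bails on non-Metal backends.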
pub(crate) fn afq_mm_op(
    x: &Tensor,
    w: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
    transpose: bool,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

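    // Validate dtypes and shapes up front, and recover the logical (unpacked)
    // inner/outer dims of `w` from its packed last dimension.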
    let w_outer_dims = {
        if w.dtype() != DType::U32 {
            candle_core::bail!("AFQ weight matrix must be u32");
        }
        if scales.dims() != biases.dims() {
            candle_core::bail!("Scales and biases must have the same shape");
        }
        if w.dim(D::Minus1)? * 32 / bits != scales.dim(D::Minus1)? * group_size {
            candle_core::bail!("Last dims of w and scales must be compatible.");
        }

        let x_inner_dims = x.dim(D::Minus1)?;

        let w_inner_dims = if transpose {
            w.dim(D::Minus1)? * 32 / bits
        } else {
            w.dim(D::Minus2)?
        };
        let w_outer_dims = if transpose {
            w.dim(D::Minus2)?
        } else {
            w.dim(D::Minus1)? * 32 / bits
        };

        if w_inner_dims != x_inner_dims {
            candle_core::bail!(
                "w inner dims ({:?}) must match x inner dims ({:?}). transpose={transpose}",
                w.dims(),
                x.dims()
            );
        }

        w_outer_dims
    };

    #[cfg(feature = "metal")]
    {
        let x_s = x.storage_and_layout().0;
        let Storage::Metal(x_s) = &*x_s else {
            candle_core::bail!("expected metal")
        };
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-qmm");

        let mut out_shape = x.dims().to_vec();
        *out_shape.last_mut().unwrap() = w_outer_dims;

        let output =
            device.new_buffer(out_shape.iter().product(), scales.dtype(), "afq-qmm-output")?;

        assert_eq!(x.layout().start_offset(), 0);
        assert_eq!(w.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);

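        // Launch the quantized-matmul kernel: it consumes x, the packed w, and
        // the per-group scales/biases, writing the product in scales' dtype.
        // Note the trailing argument order here is (transpose, bits, group_size),
        // unlike the (group_size, bits) order of the affine-quantize calls.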
        crate::metal_kernels::call_afq_qmm(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            scales.dtype(),
            x_s.buffer(),
            x.dims(),
            x.stride(),
            w_s.buffer(),
            w.dims(),
            w.stride(),
            s_s.buffer(),
            scales.stride(),
            b_s.buffer(),
            biases.stride(),
            &output,
            &out_shape,
            transpose,
            bits,
            group_size,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                out_shape.iter().product(),
                scales.dtype(),
            )),
            out_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_mm_op` only works on Metal.")
    }
}

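// Round-trip accuracy tests: quantize then dequantize random weights and check
// that the RMSE stays below a per-bit-width bound (looser as bits shrink).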
#[cfg(feature = "metal")]
#[cfg(test)]
mod metal_tests {
    use candle_core::{DType, Device, Result, Tensor, D};

    use crate::{afq::ops::afq_dequantize_op, AfqBits, AfqGroupSize};

    use super::afq_quantize_op;

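    /// Quantize-dequantize a random 32x32 matrix and return the RMSE between
    /// the input and its reconstruction.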
    fn run_afq_roundtrip(bits: AfqBits) -> Result<f32> {
        let device = Device::new_metal(0)?;
        let group_size = AfqGroupSize::Low;

        let xs = Tensor::randn(0f32, 1f32, (32, 32), &device)?;

        let (w_q, scales, biases) = afq_quantize_op(&xs, group_size, bits)?;

        let ys = afq_dequantize_op(&w_q, &scales, &biases, group_size, bits)?;

        let rmse = (xs - ys)?
            .sqr()?
            .mean(D::Minus1)?
            .sqrt()?
            .mean_all()?
            .to_dtype(DType::F32)?
            .to_scalar::<f32>()?;

        Ok(rmse)
    }

    #[test]
    fn test_afq_eight() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Eight)?;
        assert!(rmse < 0.005, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_six() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Six)?;
        assert!(rmse < 0.02, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_four() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Four)?;
        assert!(rmse < 0.078, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_three() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Three)?;
        assert!(rmse < 0.17, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_two() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Two)?;
        assert!(rmse < 0.35, "{rmse}");
        Ok(())
    }
}