mistralrs_quant/afq/ops.rs

#![allow(unused)]

use candle_core::{
    backend::BackendStorage, from_storage_no_op, DType, MetalStorage, Result, Storage, Tensor, D,
};

use super::{AfqBits, AfqGroupSize};

/// AFQ-quantizes `w`, returning `(w_q, scales, biases)`.
pub(crate) fn afq_quantize_op(
    w: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<(Tensor, Tensor, Tensor)> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w.rank() < 2 {
        candle_core::bail!("AFQ quantize expects weight matrix of at least rank 2");
    }
    if w.dim(D::Minus1)? % group_size != 0 {
        candle_core::bail!(
            "Last dim of weight matrix ({:?}) must be divisible by group size {group_size}.",
            w.dims()
        );
    }

    #[cfg(feature = "metal")]
    {
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-quantize");

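        // Quantized values are packed into u32 words (32 / bits values per word), so the
        // packed last dim is last_dim * bits / 32; scales and biases hold one entry per
        // group (last_dim / group_size).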
        let mut wq_shape = w.dims().to_vec();
        *wq_shape.last_mut().unwrap() = w.dim(D::Minus1)? * bits / 32;
        let mut s_shape = w.dims().to_vec();
        *s_shape.last_mut().unwrap() = w.dim(D::Minus1)? / group_size;

        let output =
            device.new_buffer(wq_shape.iter().product(), DType::U32, "afq-quantize-output")?;
        let scales =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-scales")?;
        let biases =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-biases")?;

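        // Inputs must start at offset 0; the `false` argument marks this as the quantize
        // direction (the dequantize op below passes `true` to the same kernel).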
        assert_eq!(w.layout().start_offset(), 0);
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            w.dtype(),
            w_s.buffer(),
            w.dims(),
            w.stride(),
            &output,
            &wq_shape,
            &scales,
            &biases,
            false,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                wq_shape.iter().product(),
                DType::U32,
            )),
            wq_shape,
            false,
        );
        let scales = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                scales,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape.clone(),
            false,
        );
        let biases = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                biases,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape,
            false,
        );

        Ok((output, scales, biases))
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_quantize_op` only works on Metal.")
    }
}

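/// Dequantizes `w_q` back to the dtype of `scales`, using the given scales and biases.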
pub(crate) fn afq_dequantize_op(
    w_q: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w_q.rank() < 2 || scales.rank() < 2 || biases.rank() < 2 {
        candle_core::bail!("AFQ dequantize expects all matrices of at least rank 2");
    }

    #[cfg(feature = "metal")]
    {
        let wq_s = w_q.storage_and_layout().0;
        let Storage::Metal(wq_s) = &*wq_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = wq_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-dequantize");

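        // Inverse of the packing above: each u32 word expands into 32 / bits values.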
        let out_size = w_q.dim(D::Minus1)? * 32 / bits;
        let mut w_shape = w_q.dims().to_vec();
        *w_shape.last_mut().unwrap() = out_size;

        if out_size != scales.dim(D::Minus1)? * group_size
            || out_size != biases.dim(D::Minus1)? * group_size
        {
            candle_core::bail!(
                "Scales and biases do not match the matrix for the given dequantization parameters."
            );
        }

        let output = device.new_buffer(
            w_shape.iter().product(),
            scales.dtype(),
            "afq-dequantize-output",
        )?;

        assert_eq!(w_q.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            scales.dtype(),
            wq_s.buffer(),
            w_q.dims(),
            w_q.stride(),
            &output,
            &w_shape,
            s_s.buffer(),
            b_s.buffer(),
            true,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                w_shape.iter().product(),
                scales.dtype(),
            )),
            w_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_dequantize_op` only works on Metal.")
    }
}

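/// Quantized matmul: computes `x @ w` (or `x @ w^T` when `transpose` is set), where `w` is
/// AFQ-quantized and described by `scales` and `biases`.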
pub(crate) fn afq_mm_op(
    x: &Tensor,
    w: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
    transpose: bool,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

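    // Validate shapes and derive the output's outer dim. `w` is packed u32, so its
    // dequantized width is the packed last dim * 32 / bits, which must equal
    // scales' last dim * group_size.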
    let w_outer_dims = {
        if w.dtype() != DType::U32 {
            candle_core::bail!("AFQ weight matrix must be u32");
        }
        if scales.dims() != biases.dims() {
            candle_core::bail!("Scales and biases should have the same shapes");
        }
        if w.dim(D::Minus1)? * 32 / bits != scales.dim(D::Minus1)? * group_size {
            candle_core::bail!("Last dims of w and scales must be compatible.");
        }

        let x_inner_dims = x.dim(D::Minus1)?;

        // Derive w's effective inner/outer dims, accounting for transpose
        let w_inner_dims = if transpose {
            w.dim(D::Minus1)? * 32 / bits
        } else {
            w.dim(D::Minus2)?
        };
        let w_outer_dims = if transpose {
            w.dim(D::Minus2)?
        } else {
            w.dim(D::Minus1)? * 32 / bits
        };

        if w_inner_dims != x_inner_dims {
            candle_core::bail!(
                "w inner dims ({:?}) must match x inner dims ({:?}). transpose={transpose}",
                w.dims(),
                x.dims()
            );
        }

        w_outer_dims
    };

    #[cfg(feature = "metal")]
    {
        let x_s = x.storage_and_layout().0;
        let Storage::Metal(x_s) = &*x_s else {
            candle_core::bail!("expected metal")
        };
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-qmm");

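        // The output keeps x's leading dims; only the last dim changes to w's outer dim.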
        let mut out_shape = x.dims().to_vec();
        *out_shape.last_mut().unwrap() = w_outer_dims;

        let output =
            device.new_buffer(out_shape.iter().product(), scales.dtype(), "afq-qmm-output")?;

        assert_eq!(x.layout().start_offset(), 0);
        assert_eq!(w.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);

        crate::metal_kernels::call_afq_qmm(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            scales.dtype(),
            x_s.buffer(),
            x.dims(),
            x.stride(),
            w_s.buffer(),
            w.dims(),
            w.stride(),
            s_s.buffer(),
            scales.stride(),
            b_s.buffer(),
            biases.stride(),
            &output,
            &out_shape,
            transpose,
            bits,
            group_size,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                out_shape.iter().product(),
                scales.dtype(),
            )),
            out_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_mm_op` only works on Metal.")
    }
}

#[cfg(feature = "metal")]
#[cfg(test)]
mod metal_tests {
    use candle_core::{DType, Device, Result, Tensor, D};

    use crate::{afq::ops::afq_dequantize_op, AfqBits, AfqGroupSize};

    use super::afq_quantize_op;

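    /// Quantizes then dequantizes a random 32x32 matrix, returning the mean per-row RMSE
    /// of the reconstruction error.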
    fn run_afq_roundtrip(bits: AfqBits) -> Result<f32> {
        let device = Device::new_metal(0)?;
        let group_size = AfqGroupSize::Low;

        let xs = Tensor::randn(0f32, 1f32, (32, 32), &device)?;

        let (w_q, scales, biases) = afq_quantize_op(&xs, group_size, bits)?;

        // println!("w_q = {w_q}");
        // println!("scales = {scales}");
        // println!("biases = {biases}");

        let ys = afq_dequantize_op(&w_q, &scales, &biases, group_size, bits)?;

        // println!("xs = {xs}");
        // println!("ys = {ys}");
        // println!("delta = {}", (xs - ys)?);

        let rmse = (xs - ys)?
            .sqr()?
            .mean(D::Minus1)?
            .sqrt()?
            .mean_all()?
            .to_dtype(DType::F32)?
            .to_scalar::<f32>()?;

        Ok(rmse)
    }

    #[test]
    fn test_afq_eight() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Eight)?;
        assert!(rmse < 0.005, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_six() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Six)?;
        assert!(rmse < 0.02, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_four() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Four)?;
        assert!(rmse < 0.078, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_three() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Three)?;
        assert!(rmse < 0.17, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_two() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Two)?;
        assert!(rmse < 0.35, "{rmse}");
        Ok(())
    }
}