mistralrs_quant/afq/ops.rs

#![allow(unused)]

use candle_core::{
    backend::BackendStorage, from_storage_no_op, DType, MetalStorage, Result, Shape, Storage,
    Tensor, D,
};

use super::{AfqBits, AfqGroupSize};

/// Affine-quantizes `w`, returning `(w_q, scales, biases)`.
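///
/// Within each group of `group_size` elements along the last dim, values are
/// reconstructed approximately as `w ≈ w_q * scale + bias`, with the `bits`-bit
/// codes packed into `u32` words (`32 / bits` codes per word).
///
/// A minimal usage sketch (assumes a Metal device; shapes follow the checks in
/// the body below):
///
/// ```ignore
/// let w = Tensor::randn(0f32, 1f32, (64, 64), &device)?;
/// let (w_q, scales, biases) = afq_quantize_op(&w, AfqGroupSize::Low, AfqBits::Four)?;
/// // w_q: (64, 64 * 4 / 32) u32; scales, biases: (64, 64 / group_size)
/// ```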
pub(crate) fn afq_quantize_op(
    w: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<(Tensor, Tensor, Tensor)> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w.rank() < 2 {
        candle_core::bail!("AFQ quantize expects weight matrix of at least rank 2");
    }
    if w.dim(D::Minus1)? % group_size != 0 {
        candle_core::bail!(
            "Last dim of weight matrix ({:?}) must be divisible by group size {group_size}.",
            w.dims()
        );
    }

    #[cfg(feature = "metal")]
    {
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-quantize");

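        // Packed layout: each u32 word of the output holds 32 / bits codes, and
        // each group of group_size input values shares one (scale, bias) pair.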
        let mut wq_shape = w.dims().to_vec();
        *wq_shape.last_mut().unwrap() = w.dim(D::Minus1)? * bits / 32;
        let mut s_shape = w.dims().to_vec();
        *s_shape.last_mut().unwrap() = w.dim(D::Minus1)? / group_size;

        let output =
            device.new_buffer(wq_shape.iter().product(), DType::U32, "afq-quantize-output")?;
        let scales =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-scales")?;
        let biases =
            device.new_buffer(s_shape.iter().product(), w.dtype(), "afq-quantize-biases")?;

        assert_eq!(w.layout().start_offset(), 0);
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            w.dtype(),
            w_s.buffer(),
            w.dims(),
            w.stride(),
            &output,
            &wq_shape,
            &scales,
            &biases,
            false,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                wq_shape.iter().product(),
                DType::U32,
            )),
            wq_shape,
            false,
        );
        let scales = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                scales,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape.clone(),
            false,
        );
        let biases = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                biases,
                device.clone(),
                s_shape.iter().product(),
                w.dtype(),
            )),
            s_shape,
            false,
        );

        Ok((output, scales, biases))
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_quantize_op` only works on Metal.")
    }
}

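/// Inverse of [`afq_quantize_op`]: reconstructs `w ≈ w_q * scales + biases`
/// groupwise, expanding each packed `u32` word of `w_q` into `32 / bits`
/// values (the roundtrip is exercised by the `metal_tests` module below).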
pub(crate) fn afq_dequantize_op(
    w_q: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    group_size: AfqGroupSize,
    bits: AfqBits,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    if w_q.rank() < 2 || scales.rank() < 2 || biases.rank() < 2 {
        candle_core::bail!("AFQ dequantize expects all matrices of at least rank 2");
    }

    #[cfg(feature = "metal")]
    {
        let wq_s = w_q.storage_and_layout().0;
        let Storage::Metal(wq_s) = &*wq_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = wq_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-dequantize");

        let out_size = w_q.dim(D::Minus1)? * 32 / bits;
        let mut w_shape = w_q.dims().to_vec();
        *w_shape.last_mut().unwrap() = out_size;

        if out_size != scales.dim(D::Minus1)? * group_size
            || out_size != biases.dim(D::Minus1)? * group_size
        {
            candle_core::bail!(
                "Scales and biases do not match the quantized matrix for the given dequantization parameters."
            );
        }

        let output = device.new_buffer(
            w_shape.iter().product(),
            scales.dtype(),
            "afq-dequantize-output",
        )?;

        assert_eq!(w_q.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);
        crate::metal_kernels::call_affine_quantize(
            device.device(),
            &command_buffer,
            &crate::metal_kernels::Kernels::new(),
            scales.dtype(),
            wq_s.buffer(),
            w_q.dims(),
            w_q.stride(),
            &output,
            &w_shape,
            s_s.buffer(),
            b_s.buffer(),
            true,
            group_size,
            bits,
        )
        .map_err(candle_core::Error::wrap)?;

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                w_shape.iter().product(),
                scales.dtype(),
            )),
            w_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_dequantize_op` only works on Metal.")
    }
}

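/// Builds identity batch indices `0..B-1`, reshaped to the batch dims of `x`
/// (i.e. all but the last two), for the gather path of `call_afq_qmm` when the
/// caller did not supply indices for one side.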
fn make_dummy_indices(x: &Tensor) -> Result<Tensor> {
    let x_batches = x
        .dims()
        .iter()
        .take(x.rank() - 2)
        .copied()
        .collect::<Vec<_>>();

    (Tensor::ones(x_batches.iter().product::<usize>(), DType::F32, x.device())?
        .cumsum(0)?
        .to_dtype(DType::U32)?
        - 1.)?
        .reshape(x_batches)
}

/// The indices `lhs_indices` and `rhs_indices` contain flat indices along the
/// batch dimensions (i.e. all but the last two dimensions) of `x` and `w`
/// respectively. If only one of them is provided, identity indices are
/// generated for the other operand.
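///
/// A minimal sketch of the gather path (hypothetical shapes; Metal only):
///
/// ```ignore
/// // x: (B, M, K) activations; w, scales, biases: a quantized stack of E weights.
/// // Route batch 0 of x through weight 2 of the stack:
/// let lhs_indices = Tensor::new(&[0u32], &device)?;
/// let rhs_indices = Tensor::new(&[2u32], &device)?;
/// let y = afq_mm_op(
///     &x, &w, &scales, &biases,
///     Some(&lhs_indices), Some(&rhs_indices),
///     group_size, bits, /* transpose = */ true,
/// )?; // y: (1, M, N)
/// ```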
#[allow(clippy::too_many_arguments)]
pub(crate) fn afq_mm_op(
    x: &Tensor,
    w: &Tensor,
    scales: &Tensor,
    biases: &Tensor,
    lhs_indices: Option<&Tensor>,
    rhs_indices: Option<&Tensor>,
    group_size: AfqGroupSize,
    bits: AfqBits,
    transpose: bool,
) -> Result<Tensor> {
    let group_size = group_size as usize;
    let bits = bits as usize;

    let w_outer_dims = {
        if w.dtype() != DType::U32 {
            candle_core::bail!("AFQ weight matrix must be u32");
        }
        if scales.dims() != biases.dims() {
            candle_core::bail!("Scales and biases must have the same shape");
        }
        if w.dim(D::Minus1)? * 32 / bits != scales.dim(D::Minus1)? * group_size {
            candle_core::bail!("Last dims of w and scales must be compatible.");
        }

        let x_inner_dims = x.dim(D::Minus1)?;

        // Effective (unpacked) w dims: transpose selects x @ dequant(w)^T
        // instead of x @ dequant(w).
        let w_inner_dims = if transpose {
            w.dim(D::Minus1)? * 32 / bits
        } else {
            w.dim(D::Minus2)?
        };
        let w_outer_dims = if transpose {
            w.dim(D::Minus2)?
        } else {
            w.dim(D::Minus1)? * 32 / bits
        };

        if w_inner_dims != x_inner_dims {
            candle_core::bail!(
                "w inner dims ({:?}) must match x inner dims ({:?}). transpose={transpose}",
                w.dims(),
                x.dims()
            );
        }

        w_outer_dims
    };

    #[cfg(feature = "metal")]
    {
        let x_s = x.storage_and_layout().0;
        let Storage::Metal(x_s) = &*x_s else {
            candle_core::bail!("expected metal")
        };
        let w_s = w.storage_and_layout().0;
        let Storage::Metal(w_s) = &*w_s else {
            candle_core::bail!("expected metal")
        };
        let s_s = scales.storage_and_layout().0;
        let Storage::Metal(s_s) = &*s_s else {
            candle_core::bail!("expected metal")
        };
        let b_s = biases.storage_and_layout().0;
        let Storage::Metal(b_s) = &*b_s else {
            candle_core::bail!("expected metal")
        };

        let device = w_s.device();

        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("afq-qmm");

        assert_eq!(x.layout().start_offset(), 0);
        assert_eq!(w.layout().start_offset(), 0);
        assert_eq!(scales.layout().start_offset(), 0);
        assert_eq!(biases.layout().start_offset(), 0);

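        // Gather path: if any indices were provided, they select batches of x
        // and w to multiply; otherwise perform a plain batched quantized matmul.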
        let (output, out_shape) = if lhs_indices.is_some() || rhs_indices.is_some() {
            let mut lhs_indices = match lhs_indices {
                Some(lhs_indices) => lhs_indices.clone(),
                None => make_dummy_indices(x)?,
            };
            let mut rhs_indices = match rhs_indices {
                Some(rhs_indices) => rhs_indices.clone(),
                None => make_dummy_indices(w)?,
            };
            if lhs_indices.dtype() != DType::U32 || rhs_indices.dtype() != DType::U32 {
                candle_core::bail!("lhs and rhs indices must be u32.")
            }
            // Broadcast the indices against each other if applicable.
            {
                let shape = lhs_indices
                    .shape()
                    .broadcast_shape_binary_op(rhs_indices.shape(), "afq-qmm")?;
                lhs_indices = lhs_indices.broadcast_as(shape.clone())?;
                rhs_indices = rhs_indices.broadcast_as(shape)?;
            }

            let li_s = lhs_indices.storage_and_layout().0;
            let Storage::Metal(li_s) = &*li_s else {
                candle_core::bail!("expected metal")
            };
            let ri_s = rhs_indices.storage_and_layout().0;
            let Storage::Metal(ri_s) = &*ri_s else {
                candle_core::bail!("expected metal")
            };

            let mut out_shape = lhs_indices.dims().to_vec();
            out_shape.push(x.dim(D::Minus2)?);
            out_shape.push(w_outer_dims);

            let output =
                device.new_buffer(out_shape.iter().product(), scales.dtype(), "afq-qmm-output")?;

            crate::metal_kernels::call_afq_qmm(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                scales.dtype(),
                x_s.buffer(),
                x.dims(),
                x.stride(),
                w_s.buffer(),
                w.dims(),
                w.stride(),
                s_s.buffer(),
                scales.stride(),
                b_s.buffer(),
                biases.stride(),
                &output,
                &out_shape,
                Some((li_s.buffer(), ri_s.buffer())),
                Some(lhs_indices.dims()),
                Some((lhs_indices.stride(), rhs_indices.stride())),
                transpose,
                bits,
                group_size,
            )
            .map_err(candle_core::Error::wrap)?;

            (output, out_shape)
        } else {
            let mut out_shape = x.dims().to_vec();
            *out_shape.last_mut().unwrap() = w_outer_dims;

            let output =
                device.new_buffer(out_shape.iter().product(), scales.dtype(), "afq-qmm-output")?;

            crate::metal_kernels::call_afq_qmm(
                device.device(),
                &command_buffer,
                &crate::metal_kernels::Kernels::new(),
                scales.dtype(),
                x_s.buffer(),
                x.dims(),
                x.stride(),
                w_s.buffer(),
                w.dims(),
                w.stride(),
                s_s.buffer(),
                scales.stride(),
                b_s.buffer(),
                biases.stride(),
                &output,
                &out_shape,
                None,
                None,
                None,
                transpose,
                bits,
                group_size,
            )
            .map_err(candle_core::Error::wrap)?;

            (output, out_shape)
        };

        let output = from_storage_no_op(
            Storage::Metal(MetalStorage::new(
                output,
                device.clone(),
                out_shape.iter().product(),
                scales.dtype(),
            )),
            out_shape,
            false,
        );

        Ok(output)
    }
    #[cfg(not(feature = "metal"))]
    {
        candle_core::bail!("`afq_mm_op` only works on Metal.")
    }
}

#[cfg(feature = "metal")]
#[cfg(test)]
mod metal_tests {
    use candle_core::{DType, Device, Result, Tensor, D};

    use crate::{afq::ops::afq_dequantize_op, AfqBits, AfqGroupSize};

    use super::afq_quantize_op;

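    /// Quantizes then dequantizes a random 32x32 matrix, returning the mean
    /// per-row RMSE of the reconstruction.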
    fn run_afq_roundtrip(bits: AfqBits) -> Result<f32> {
        let device = Device::new_metal(0)?;
        let group_size = AfqGroupSize::Low;

        let xs = Tensor::randn(0f32, 1f32, (32, 32), &device)?;

        let (w_q, scales, biases) = afq_quantize_op(&xs, group_size, bits)?;

        // println!("w_q = {w_q}");
        // println!("scales = {scales}");
        // println!("biases = {biases}");

        let ys = afq_dequantize_op(&w_q, &scales, &biases, group_size, bits)?;

        // println!("xs = {xs}");
        // println!("ys = {ys}");
        // println!("delta = {}", (xs - ys)?);

        let rmse = (xs - ys)?
            .sqr()?
            .mean(D::Minus1)?
            .sqrt()?
            .mean_all()?
            .to_dtype(DType::F32)?
            .to_scalar::<f32>()?;

        Ok(rmse)
    }

    #[test]
    fn test_afq_eight() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Eight)?;
        assert!(rmse < 0.005, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_six() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Six)?;
        assert!(rmse < 0.02, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_four() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Four)?;
        assert!(rmse < 0.078, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_three() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Three)?;
        assert!(rmse < 0.17, "{rmse}");
        Ok(())
    }

    #[test]
    fn test_afq_two() -> Result<()> {
        let rmse = run_afq_roundtrip(AfqBits::Two)?;
        assert!(rmse < 0.35, "{rmse}");
        Ok(())
    }
}