mistralrs_quant/utils/isq.rs

use std::sync::{atomic::AtomicUsize, Arc};

use candle_core::{quantized::GgmlDType, Device, Result, Tensor};

use crate::{
    get_immediate_isq, should_apply_immediate_isq, ImmediateIsqParams, QuantMethod,
    ShardedVarBuilder,
};

pub enum QuantizationBehavior {
    Quantize(GgmlDType),
    Skip,
}

pub fn apply_immediate_isq(
    layer: Arc<dyn QuantMethod>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if should_apply_immediate_isq(&vb) {
        apply_immediate_isq_always(layer, vb.device())
    } else {
        Ok(layer)
    }
}
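
// A minimal usage sketch (illustrative only; `layer` and `vb` stand for whatever
// quantized layer and var builder the caller holds at load time):
//
//     let layer: Arc<dyn QuantMethod> = /* construct the layer from `vb` */;
//     let layer = apply_immediate_isq(layer, vb)?;
//
// When immediate ISQ is not active for this var builder, the layer is returned
// unchanged.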

pub(crate) fn apply_immediate_isq_always(
    layer: Arc<dyn QuantMethod>,
    device: &Device,
) -> Result<Arc<dyn QuantMethod>> {
    if let Some(ImmediateIsqParams {
        guard,
        ty: Some(immediate_isq),
        predicates: _,
    }) = get_immediate_isq()
    {
        layer.clone().apply_isq(
            Some(immediate_isq),
            device.clone(),
            &AtomicUsize::new(0),
            None,
            guard,
        )
    } else {
        Ok(layer)
    }
}

/// Return the fallback quantization behaviour for the given dtype.
fn get_fallback(dtype: GgmlDType) -> QuantizationBehavior {
    // The normal `Q` quants are a bit more lenient than the `K` quants,
    // so try to fall back to a similar `Q` quant.
    // If that's not possible, skip this tensor.
    match dtype {
        GgmlDType::Q2K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
        GgmlDType::Q3K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
        GgmlDType::Q4K => QuantizationBehavior::Quantize(GgmlDType::Q4_1),
        GgmlDType::Q5K => QuantizationBehavior::Quantize(GgmlDType::Q5_0),
        GgmlDType::Q6K => QuantizationBehavior::Quantize(GgmlDType::Q5_1),
        GgmlDType::Q8K => QuantizationBehavior::Quantize(GgmlDType::Q8_1),
        _ => QuantizationBehavior::Skip,
    }
}

/// Check whether the tensor can be quantized with the given dtype.
fn can_quantize(tensor: &Tensor, dtype: GgmlDType) -> bool {
    let dims = tensor.shape().dims();
    // The tensor must not be empty and its last dimension must be a multiple of the block size.
    !(dims.is_empty() || (dims[dims.len() - 1] % dtype.block_size() != 0))
}
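
// For intuition (GGML block sizes, stated here as convention rather than anything
// defined in this file): the `K` quants use 256-element blocks while the plain
// `Q` quants use 32-element blocks. A weight of shape [4096, 11008] therefore
// passes `can_quantize` for any of them, but one whose last dimension is, say,
// 320 fails the Q4K check and is retried as Q4_1 by `get_quantization_behaviour`
// below.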

/// Check if we should quantize the tensor and if so, with which dtype.
pub(crate) fn get_quantization_behaviour(
    tensor: &Tensor,
    dtype: GgmlDType,
) -> QuantizationBehavior {
    if dtype == GgmlDType::F32 {
        return QuantizationBehavior::Skip;
    }

    if can_quantize(tensor, dtype) {
        return QuantizationBehavior::Quantize(dtype);
    }
    let fallback = get_fallback(dtype);
    match fallback {
        QuantizationBehavior::Skip => fallback,
        QuantizationBehavior::Quantize(new_dtype) => get_quantization_behaviour(tensor, new_dtype),
    }
}
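
// Sketch of a call site (hypothetical, not code from this module):
//
//     match get_quantization_behaviour(&weight, GgmlDType::Q6K) {
//         QuantizationBehavior::Quantize(dt) => { /* quantize `weight` with `dt` */ }
//         QuantizationBehavior::Skip => { /* keep `weight` unquantized */ }
//     }
//
// Note the recursion above: if Q6K does not fit, the tensor is retried with Q5_1,
// and it is only skipped when even a 32-element-block `Q` quant cannot divide the
// last dimension.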

#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq {
    ($tensor:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            let quantization_behaviour = $crate::utils::isq::get_quantization_behaviour(&$tensor, $dtype);
            let dtype = match quantization_behaviour {
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = $tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize(&$tensor, dtype)?;
            let data = initial.data()?;

            let _acquired_quantize_guard = $guard.acquire(&$device);
            let qstorage = candle_core::quantized::QStorage::from_data(data, &$device, dtype)?;

            Arc::new(candle_core::quantized::QTensor::new(qstorage, $tensor.shape())?)
        }
    };
}
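
// Hypothetical expansion site for the macro above (a sketch; `weight`, `device`,
// and `guard` are assumptions standing in for the caller's tensor, target device,
// and quantize guard):
//
//     let n_quantized = AtomicUsize::new(0);
//     let qtensor = generate_isq!(weight, device, GgmlDType::Q4K, n_quantized, guard);
//
// The counter only advances when the tensor is actually quantized; skipped tensors
// are kept as F32 and a warning is emitted through `once_log_warn`.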

#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq_imatrix {
    ($tensor:expr, $imatrix:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            let quantization_behaviour = $crate::utils::isq::get_quantization_behaviour(&$tensor, $dtype);
            let dtype = match quantization_behaviour {
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = $tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize_imatrix(&$tensor, &$imatrix, dtype)?;
            if !$tensor.device().is_cpu() {
                // Short-circuit: the tensor was quantized on a non-CPU device, so skip the
                // raw-data round-trip used below for the CPU path.
                Arc::new(initial)
            } else {
                let data = initial.data()?;

                let _acquired_quantize_guard = $guard.acquire(&$device);
                let qstorage = candle_core::quantized::QStorage::from_data(data, &$device, dtype)?;

                Arc::new(candle_core::quantized::QTensor::new(qstorage, $tensor.shape())?)
            }
        }
    };
}