mistralrs_quant/utils/
isq.rs

1use std::sync::{atomic::AtomicUsize, Arc};
2
3use candle_core::{quantized::GgmlDType, Device, Result, Tensor};
4
5use crate::{
6    get_immediate_isq, ImmediateIsqMatch, ImmediateIsqParams, QuantMethod, ShardedVarBuilder,
7};
8
/// Decision for how a tensor should be handled during in-situ quantization (ISQ).
pub enum QuantizationBehavior {
    /// Quantize the tensor with the given GGML dtype.
    Quantize(GgmlDType),
    /// Leave the tensor unquantized (shape incompatible with any candidate block size).
    Skip,
}
13
14pub fn apply_immediate_isq(
15    layer: Arc<dyn QuantMethod>,
16    vb: ShardedVarBuilder,
17) -> Result<Arc<dyn QuantMethod>> {
18    let Some(params) = get_immediate_isq() else {
19        return Ok(layer);
20    };
21    let prefix = format!("{}.weight", vb.prefix());
22    if let Some(ImmediateIsqMatch { ty, device }) = crate::resolve_immediate_isq(&params, &prefix) {
23        let device = device.unwrap_or_else(|| vb.device().clone());
24        layer.clone().apply_isq(
25            Some(ty),
26            device,
27            &AtomicUsize::new(0),
28            None,
29            params.guard.clone(),
30        )
31    } else {
32        Ok(layer)
33    }
34}
35
36pub(crate) fn apply_immediate_isq_always(
37    layer: Arc<dyn QuantMethod>,
38    device: &Device,
39) -> Result<Arc<dyn QuantMethod>> {
40    if let Some(ImmediateIsqParams {
41        guard,
42        ty: Some(immediate_isq),
43        predicates: _,
44        ..
45    }) = get_immediate_isq()
46    {
47        layer.clone().apply_isq(
48            Some(immediate_isq),
49            device.clone(),
50            &AtomicUsize::new(0),
51            None,
52            guard,
53        )
54    } else {
55        Ok(layer)
56    }
57}
58
59/// Return the fallback dtype for the given dtype.
60fn get_fallback(dtype: GgmlDType) -> QuantizationBehavior {
61    // The normal `Q` quants are a bit more lenient than the `K` quants.
62    // => Try to fallback to a similar `Q` quant.
63    // If that's not possible, skip this tensor.
64    match dtype {
65        GgmlDType::Q2K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
66        GgmlDType::Q3K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
67        GgmlDType::Q4K => QuantizationBehavior::Quantize(GgmlDType::Q4_1),
68        GgmlDType::Q5K => QuantizationBehavior::Quantize(GgmlDType::Q5_0),
69        GgmlDType::Q6K => QuantizationBehavior::Quantize(GgmlDType::Q5_1),
70        GgmlDType::Q8K => QuantizationBehavior::Quantize(GgmlDType::Q8_1),
71        _ => QuantizationBehavior::Skip,
72    }
73}
74
75/// Check if the tensor can be quantized with the given dtype.
76fn can_quantize(tensor: &Tensor, dtype: GgmlDType) -> bool {
77    let dims = tensor.shape().dims();
78    // The tensor must not be empty and the last dimension must be a multiple of the block size.
79    !(dims.is_empty() || (dims[dims.len() - 1] % dtype.block_size() != 0))
80}
81
82/// Check if we should quantize the tensor and if so, with which dtype.
83pub(crate) fn get_quantization_behaviour(
84    tensor: &Tensor,
85    dtype: GgmlDType,
86) -> QuantizationBehavior {
87    if dtype == GgmlDType::F32 {
88        return QuantizationBehavior::Skip;
89    }
90
91    if can_quantize(tensor, dtype) {
92        return QuantizationBehavior::Quantize(dtype);
93    }
94    let fallback = get_fallback(dtype);
95    match fallback {
96        QuantizationBehavior::Skip => fallback,
97        QuantizationBehavior::Quantize(new_dtype) => get_quantization_behaviour(tensor, new_dtype),
98    }
99}
100
/// Quantize `$tensor` to `$dtype` and return it as an `Arc<QTensor>` on `$device`.
///
/// The effective dtype is resolved via `get_quantization_behaviour`: it may fall
/// back to a more lenient quant or degrade to `F32` (with a one-time warning) if
/// the tensor's shape is incompatible. `$n_quantized` is incremented only when a
/// real quantization happens. `$guard` serializes the device-side storage upload.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq {
    ($tensor:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            // Resolve the dtype we will actually use (requested, fallback, or F32-skip).
            let quantization_behaviour = $crate::utils::isq::get_quantization_behaviour(&$tensor, $dtype);
            let dtype = match quantization_behaviour{
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = $tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    // F32 "quantization" is effectively a passthrough.
                    GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            // Quantize first (on the tensor's current device), then extract raw bytes.
            let initial = candle_core::quantized::QTensor::quantize(&$tensor, dtype)?;
            let data = initial.data()?;

            // Hold the guard only for the upload of the quantized bytes to `$device`.
            let _acquired_quantize_guard = $guard.acquire(&$device);
            let qstorage = candle_core::quantized::QStorage::from_data(data, &$device, dtype)?;

            Arc::new(candle_core::quantized::QTensor::new(qstorage, $tensor.shape())?)
        }
    };
}
129
/// Like `generate_isq!`, but quantizes with an importance matrix (`$imatrix`)
/// via `QTensor::quantize_imatrix`.
///
/// Dtype resolution, the `$n_quantized` counter, and the `$guard`-protected
/// upload behave exactly as in `generate_isq!`. The one difference: if the
/// source tensor is already on a non-CPU device, the freshly quantized tensor
/// is returned directly instead of round-tripping its bytes through
/// `QStorage::from_data`.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq_imatrix {
    ($tensor:expr, $imatrix:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            // Resolve the dtype we will actually use (requested, fallback, or F32-skip).
            let quantization_behaviour = $crate::utils::isq::get_quantization_behaviour(&$tensor, $dtype);
            let dtype = match quantization_behaviour{
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = $tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    // F32 "quantization" is effectively a passthrough.
                    GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize_imatrix(&$tensor, &$imatrix, dtype)?;
            if !$tensor.device().is_cpu() {
                // Short-circuit here, no need for fancy
                // byte extraction + re-upload: the result already lives on-device.
                Arc::new(initial)
            } else {
                let data = initial.data()?;

                // Hold the guard only for the upload of the quantized bytes to `$device`.
                let _acquired_quantize_guard = $guard.acquire(&$device);
                let qstorage = candle_core::quantized::QStorage::from_data(data, &$device, dtype)?;

                Arc::new(candle_core::quantized::QTensor::new(qstorage, $tensor.shape())?)
            }
        }
    };
}