mistralrs_quant/utils/
isq.rs1use std::sync::{atomic::AtomicUsize, Arc};
2
3use candle_core::{quantized::GgmlDType, Device, Result, Tensor};
4
5use crate::{
6 get_immediate_isq, should_apply_immediate_isq, ImmediateIsqParams, QuantMethod,
7 ShardedVarBuilder,
8};
9
10pub enum QuantizationBehavior {
11 Quantize(GgmlDType),
12 Skip,
13}
14
15pub fn apply_immediate_isq(
16 layer: Arc<dyn QuantMethod>,
17 vb: ShardedVarBuilder,
18) -> Result<Arc<dyn QuantMethod>> {
19 if should_apply_immediate_isq(&vb) {
20 apply_immediate_isq_always(layer, vb.device())
21 } else {
22 Ok(layer)
23 }
24}
25
26pub(crate) fn apply_immediate_isq_always(
27 layer: Arc<dyn QuantMethod>,
28 device: &Device,
29) -> Result<Arc<dyn QuantMethod>> {
30 if let Some(ImmediateIsqParams {
31 guard,
32 ty: Some(immediate_isq),
33 predicates: _,
34 }) = get_immediate_isq()
35 {
36 layer.clone().apply_isq(
37 Some(immediate_isq),
38 device.clone(),
39 &AtomicUsize::new(0),
40 None,
41 guard,
42 )
43 } else {
44 Ok(layer)
45 }
46}
47
48fn get_fallback(dtype: GgmlDType) -> QuantizationBehavior {
50 match dtype {
54 GgmlDType::Q2K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
55 GgmlDType::Q3K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
56 GgmlDType::Q4K => QuantizationBehavior::Quantize(GgmlDType::Q4_1),
57 GgmlDType::Q5K => QuantizationBehavior::Quantize(GgmlDType::Q5_0),
58 GgmlDType::Q6K => QuantizationBehavior::Quantize(GgmlDType::Q5_1),
59 GgmlDType::Q8K => QuantizationBehavior::Quantize(GgmlDType::Q8_1),
60 _ => QuantizationBehavior::Skip,
61 }
62}
63
64fn can_quantize(tensor: &Tensor, dtype: GgmlDType) -> bool {
66 let dims = tensor.shape().dims();
67 !(dims.is_empty() || (dims[dims.len() - 1] % dtype.block_size() != 0))
69}
70
71pub(crate) fn get_quantization_behaviour(
73 tensor: &Tensor,
74 dtype: GgmlDType,
75) -> QuantizationBehavior {
76 if dtype == GgmlDType::F32 {
77 return QuantizationBehavior::Skip;
78 }
79
80 if can_quantize(tensor, dtype) {
81 return QuantizationBehavior::Quantize(dtype);
82 }
83 let fallback = get_fallback(dtype);
84 match fallback {
85 QuantizationBehavior::Skip => fallback,
86 QuantizationBehavior::Quantize(new_dtype) => get_quantization_behaviour(tensor, new_dtype),
87 }
88}
89
/// Quantizes `$tensor` and rebuilds the result as a `QTensor` on `$device`.
///
/// Arguments: `$tensor` (source tensor), `$device` (target device), `$dtype`
/// (requested GGML dtype), `$n_quantized` (atomic counter incremented when the
/// tensor is actually quantized), `$guard` (object whose `acquire(&device)`
/// result is held while the storage is created).
///
/// The dtype actually used comes from `get_quantization_behaviour`; when that
/// returns `Skip`, a warning is logged once and `F32` is used instead.
///
/// Evaluates to an `Arc<QTensor>`. The expansion contains `?`, so it must be
/// used inside a function whose return type can absorb candle errors.
///
/// `$tensor` and `$device` are each evaluated exactly once, and every type is
/// referred to by its full path so callers need no particular imports.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq {
    ($tensor:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            // Bind once so a non-trivial caller expression is not re-evaluated
            // at every use below.
            let tensor = &$tensor;

            let dtype = match $crate::utils::isq::get_quantization_behaviour(tensor, $dtype) {
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    candle_core::quantized::GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    // Only tensors that are really quantized are counted.
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize(tensor, dtype)?;
            let data = initial.data()?;

            // Bound at first use, and only once, for the same reason as `tensor`.
            let device = &$device;
            // Held until the end of this block, i.e. while the storage is created.
            let _acquired_quantize_guard = $guard.acquire(device);
            let qstorage = candle_core::quantized::QStorage::from_data(data, device, dtype)?;

            std::sync::Arc::new(candle_core::quantized::QTensor::new(qstorage, tensor.shape())?)
        }
    };
}
118
/// Quantizes `$tensor` using the importance matrix `$imatrix`.
///
/// Arguments: `$tensor` (source tensor), `$imatrix` (importance-matrix data
/// passed to `QTensor::quantize_imatrix`), `$device` (target device), `$dtype`
/// (requested GGML dtype), `$n_quantized` (atomic counter incremented when the
/// tensor is actually quantized), `$guard` (object whose `acquire(&device)`
/// result is held while the storage is created).
///
/// The dtype actually used comes from `get_quantization_behaviour`; when that
/// returns `Skip`, a warning is logged once and `F32` is used instead.
///
/// If the source tensor is not on the CPU, the quantized tensor is returned
/// as-is. Otherwise its raw data is rebuilt as storage on `$device`.
///
/// Evaluates to an `Arc<QTensor>`. The expansion contains `?`, so it must be
/// used inside a function whose return type can absorb candle errors.
///
/// `$tensor` and `$device` are each evaluated at most once, and every type is
/// referred to by its full path so callers need no particular imports.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq_imatrix {
    ($tensor:expr, $imatrix:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            // Bind once so a non-trivial caller expression is not re-evaluated
            // at every use below.
            let tensor = &$tensor;

            let dtype = match $crate::utils::isq::get_quantization_behaviour(tensor, $dtype) {
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    candle_core::quantized::GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    // Only tensors that are really quantized are counted.
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize_imatrix(tensor, &$imatrix, dtype)?;
            if !tensor.device().is_cpu() {
                // Short-circuit: the quantized tensor is used directly.
                std::sync::Arc::new(initial)
            } else {
                let data = initial.data()?;

                // Evaluated only on this branch, and only once.
                let device = &$device;
                // Held until the end of this block, i.e. while the storage is created.
                let _acquired_quantize_guard = $guard.acquire(device);
                let qstorage = candle_core::quantized::QStorage::from_data(data, device, dtype)?;

                std::sync::Arc::new(candle_core::quantized::QTensor::new(qstorage, tensor.shape())?)
            }
        }
    };
}