mistralrs_quant/utils/
isq.rs1use std::sync::{atomic::AtomicUsize, Arc};
2
3use candle_core::{quantized::GgmlDType, Device, Result, Tensor};
4
5use crate::{
6 get_immediate_isq, ImmediateIsqMatch, ImmediateIsqParams, QuantMethod, ShardedVarBuilder,
7};
8
9pub enum QuantizationBehavior {
10 Quantize(GgmlDType),
11 Skip,
12}
13
14pub fn apply_immediate_isq(
15 layer: Arc<dyn QuantMethod>,
16 vb: ShardedVarBuilder,
17) -> Result<Arc<dyn QuantMethod>> {
18 let Some(params) = get_immediate_isq() else {
19 return Ok(layer);
20 };
21 let prefix = format!("{}.weight", vb.prefix());
22 if let Some(ImmediateIsqMatch { ty, device }) = crate::resolve_immediate_isq(¶ms, &prefix) {
23 let device = device.unwrap_or_else(|| vb.device().clone());
24 layer.clone().apply_isq(
25 Some(ty),
26 device,
27 &AtomicUsize::new(0),
28 None,
29 params.guard.clone(),
30 )
31 } else {
32 Ok(layer)
33 }
34}
35
36pub(crate) fn apply_immediate_isq_always(
37 layer: Arc<dyn QuantMethod>,
38 device: &Device,
39) -> Result<Arc<dyn QuantMethod>> {
40 if let Some(ImmediateIsqParams {
41 guard,
42 ty: Some(immediate_isq),
43 predicates: _,
44 ..
45 }) = get_immediate_isq()
46 {
47 layer.clone().apply_isq(
48 Some(immediate_isq),
49 device.clone(),
50 &AtomicUsize::new(0),
51 None,
52 guard,
53 )
54 } else {
55 Ok(layer)
56 }
57}
58
59fn get_fallback(dtype: GgmlDType) -> QuantizationBehavior {
61 match dtype {
65 GgmlDType::Q2K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
66 GgmlDType::Q3K => QuantizationBehavior::Quantize(GgmlDType::Q4_0),
67 GgmlDType::Q4K => QuantizationBehavior::Quantize(GgmlDType::Q4_1),
68 GgmlDType::Q5K => QuantizationBehavior::Quantize(GgmlDType::Q5_0),
69 GgmlDType::Q6K => QuantizationBehavior::Quantize(GgmlDType::Q5_1),
70 GgmlDType::Q8K => QuantizationBehavior::Quantize(GgmlDType::Q8_1),
71 _ => QuantizationBehavior::Skip,
72 }
73}
74
75fn can_quantize(tensor: &Tensor, dtype: GgmlDType) -> bool {
77 let dims = tensor.shape().dims();
78 !(dims.is_empty() || (dims[dims.len() - 1] % dtype.block_size() != 0))
80}
81
82pub(crate) fn get_quantization_behaviour(
84 tensor: &Tensor,
85 dtype: GgmlDType,
86) -> QuantizationBehavior {
87 if dtype == GgmlDType::F32 {
88 return QuantizationBehavior::Skip;
89 }
90
91 if can_quantize(tensor, dtype) {
92 return QuantizationBehavior::Quantize(dtype);
93 }
94 let fallback = get_fallback(dtype);
95 match fallback {
96 QuantizationBehavior::Skip => fallback,
97 QuantizationBehavior::Quantize(new_dtype) => get_quantization_behaviour(tensor, new_dtype),
98 }
99}
100
/// Quantizes `$tensor` and evaluates to an `Arc<QTensor>` whose storage is built on `$device`.
///
/// The dtype is chosen via `get_quantization_behaviour`: when it reports `Skip`, a warning is
/// logged and `F32` is used; otherwise the selected dtype is used and `$n_quantized` is
/// incremented. The tensor is quantized, its raw data extracted, and the storage is rebuilt on
/// `$device` while the value returned by `$guard.acquire(&$device)` is held.
///
/// The expansion uses `?`, so it must be invoked inside a function returning a compatible
/// `Result`. `$tensor` and `$device` appear several times in the expansion and are therefore
/// evaluated more than once.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq {
    ($tensor:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            let quantization_behaviour = $crate::utils::isq::get_quantization_behaviour(&$tensor, $dtype);
            let dtype = match quantization_behaviour {
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = $tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    // Fully qualified so the call site does not need `GgmlDType` in scope.
                    candle_core::quantized::GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize(&$tensor, dtype)?;
            let data = initial.data()?;

            // Held until the end of this block, i.e. while the storage is created on `$device`.
            let _acquired_quantize_guard = $guard.acquire(&$device);
            let qstorage = candle_core::quantized::QStorage::from_data(data, &$device, dtype)?;

            // Fully qualified so the call site does not need `Arc` in scope.
            std::sync::Arc::new(candle_core::quantized::QTensor::new(qstorage, $tensor.shape())?)
        }
    };
}
129
/// Quantizes `$tensor` using the importance matrix `$imatrix` and evaluates to an
/// `Arc<QTensor>`.
///
/// The dtype is chosen via `get_quantization_behaviour`: when it reports `Skip`, a warning is
/// logged and `F32` is used; otherwise the selected dtype is used and `$n_quantized` is
/// incremented.
///
/// If `$tensor` is not on the CPU, the result of `quantize_imatrix` is returned directly,
/// without touching `$device` or `$guard`. Otherwise its raw data is extracted and the storage
/// is rebuilt on `$device` while the value returned by `$guard.acquire(&$device)` is held.
///
/// The expansion uses `?`, so it must be invoked inside a function returning a compatible
/// `Result`. `$tensor` and `$device` appear several times in the expansion and are therefore
/// evaluated more than once.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_isq_imatrix {
    ($tensor:expr, $imatrix:expr, $device:expr, $dtype:expr, $n_quantized:expr, $guard:expr) => {
        {
            let quantization_behaviour = $crate::utils::isq::get_quantization_behaviour(&$tensor, $dtype);
            let dtype = match quantization_behaviour {
                $crate::utils::isq::QuantizationBehavior::Skip => {
                    let shape = $tensor.shape();
                    $crate::log::once_log_warn(&format!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."));
                    // Fully qualified so the call site does not need `GgmlDType` in scope.
                    candle_core::quantized::GgmlDType::F32
                },
                $crate::utils::isq::QuantizationBehavior::Quantize(dtype) => {
                    $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    dtype
                }
            };

            let initial = candle_core::quantized::QTensor::quantize_imatrix(&$tensor, &$imatrix, dtype)?;
            if !$tensor.device().is_cpu() {
                // Non-CPU source tensor: use the quantized tensor as produced.
                // Fully qualified so the call site does not need `Arc` in scope.
                std::sync::Arc::new(initial)
            } else {
                let data = initial.data()?;

                // Held until the end of this branch, i.e. while the storage is created on `$device`.
                let _acquired_quantize_guard = $guard.acquire(&$device);
                let qstorage = candle_core::quantized::QStorage::from_data(data, &$device, dtype)?;

                std::sync::Arc::new(candle_core::quantized::QTensor::new(qstorage, $tensor.shape())?)
            }
        }
    };
}