mistralrs_quant/lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}
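
// A minimal usage sketch of the immediate-ISQ hooks above (illustrative only; the
// quant type and regex are arbitrary). Predicates are matched against the variable
// builder prefix with ".weight" appended, as in `should_apply_immediate_isq`:
//
//     set_immediate_isq(
//         Some(IsqType::Q4K),
//         vec![Regex::new(r"model\.layers\.\d+\.self_attn\..*").unwrap()],
//     );
//     // ... build layers: weights whose paths match are quantized as they load ...
//     clear_immediate_isq();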

#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                // No `quant_method` given: fall back to AFQ-style metadata, which only
                // requires `bits` and `group_size`.
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
                )))
            },
        }
    }
}
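
// Illustrative mapping (not exhaustive): an entry such as
// `{"quant_method": "gptq", "bits": 4, "group_size": 128}` deserializes into
// `QuantizedConfig::GptqAwq { bits: 4, group_size: 128, checkpoint_format: None, is_awq: false }`,
// while a config that omits `quant_method` but provides `bits` and `group_size`
// is interpreted as `QuantizedConfig::Afq`.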

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
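
// Construction sketch: backends are built by passing the matching variant to
// `QuantMethod::new`, e.g. an unquantized layer wrapping a plain `Linear` (with a
// hypothetical weight tensor `w`):
//
//     let lin = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
//         Linear::new(w, None),
//     ))?;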

/// Device/configurable intelligent matrix multiplication
/// - Handles the limitation of `accelerate`, which requires f32
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
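
// Usage sketch (assumes `a`, `b`, and `x` are existing `Tensor`s and `qm` is an
// `Arc<dyn QuantMethod>`); `MatMul` is a unit struct and can be used directly:
//
//     let c = MatMul.matmul(&a, &b)?;
//     let scaled = MatMul.matmul_affine_mul(&a, &b, 0.125)?; // matmul(a, b) * 0.125
//     let y = MatMul.qmethod_matmul(&x, qm.as_ref())?;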

/// Device/configurable intelligent convolution
/// - Handles the limitation of the CPU backend, which requires f32
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}
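
// Usage sketch (assumes `conv` is a `candle_nn::Conv2d` and `x` a `Tensor`): on CPU the
// weights and input are run in F32 and the result is cast back to the input dtype.
//
//     let y = Convolution.forward_2d(&conv, &x)?;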

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }
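
    // Worked example (assuming the standard GGML Q4K layout of 256 elements per
    // 144-byte block): pack_factor(DType::F16) = ceil(2 * 256 / 144) = 4, i.e. a
    // Q4K tensor is roughly a quarter the size of its F16 original.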

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because our HQQ quantizes on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On Metal, this flushes the command buffer to avoid "A command encoder is already encoding to this command buffer"
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
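
// Usage sketch: hold the guard across a host-side quantization critical section.
// The returned drop guard is a real mutex guard except on CUDA, where it is a no-op:
//
//     let guard = QuantizeOntoGuard::new();
//     {
//         let _lock = guard.acquire(&device);
//         // ... quantize onto the host device ...
//     } // lock (when real) released here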

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Convert this layer into an ISQ-quantized layer of the given type, quantizing onto `device`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
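
// Usage sketch (hypothetical names: `cfg.quantization_config` is an
// `Option<QuantizedConfig>` parsed from the model config, and `vb` is a
// `ShardedVarBuilder` scoped to the parent module):
//
//     let gate_proj = linear_b(
//         hidden_size,
//         intermediate_size,
//         /* bias = */ false,
//         &cfg.quantization_config,
//         vb.pp("gate_proj"),
//     )?;
//     let y = gate_proj.forward_autocast(&x)?;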