mistralrs_quant/
lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use blockwise_fp8::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}
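
// A minimal usage sketch of the immediate-ISQ flow above (illustrative; the regex pattern shown
// is a placeholder, not a pattern used by this crate). Note that the state lives in the
// thread-local `ENGINE_IMMEDIATE_ISQ`, so it only affects layers built on the calling thread:
//
//     set_immediate_isq(Some(IsqType::Q4K), vec![Regex::new(r"mlp\.").unwrap()]);
//     // ...build layers via `linear`/`linear_no_bias`; weights whose "<prefix>.weight" path
//     // matches a predicate are quantized immediately as they are loaded...
//     clear_immediate_isq();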

#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
                )))
            },
        }
    }
}
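
// For reference, the deserializer above accepts HF-style `quantization_config` JSON such as the
// following (the field values are illustrative):
//
//     {"quant_method": "gptq", "bits": 4, "group_size": 128}   -> QuantizedConfig::GptqAwq { is_awq: false, .. }
//     {"quant_method": "fp8", "weight_block_size": [128, 128]} -> QuantizedConfig::Fp8 { .. }
//     {"bits": 4, "group_size": 64}                            -> QuantizedConfig::Afq { .. } (no `quant_method` implies AFQ)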

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}
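
// A sketch of how `pack_factor` can be used to estimate quantized weight size (the config and
// the tensor dimensions below are hypothetical):
//
//     let cfg = QuantizedConfig::GptqAwq { bits: 4, group_size: 128, checkpoint_format: None, is_awq: false };
//     let elems = 4096usize * 4096;
//     // original bytes / pack factor ~= quantized bytes
//     let approx_quant_bytes = elems * DType::F16.size_in_bytes() / cfg.pack_factor(DType::F16);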

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
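
// An illustrative construction of a layer from one of these configs (a sketch; the weight shape
// `(out_dim, in_dim)` and the zero initialization are assumptions for the example):
//
//     let w = Tensor::zeros((out_dim, in_dim), DType::F16, &Device::Cpu)?;
//     let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(Linear::new(w, None)))?;
//     let y = layer.forward(&x)?;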

/// Device/configurable intelligent matrix multiplication
/// - Handles the limitation of `accelerate`, which requires f32
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
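
// Usage sketch: `MatMul` is a unit struct, so it is used directly as a value. On CPU (or with the
// `accelerate` feature) inputs are temporarily cast as shown above; `a` and `b` are assumed to be
// shape-compatible tensors:
//
//     let y = MatMul.matmul(&a, &b)?;
//     let y_scaled = MatMul.matmul_affine_div(&a, &b, 8.0)?; // (a @ b) / 8.0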

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }
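
    // Worked example: with candle's GGML block parameters, a Q4K block stores 256 elements in
    // 144 bytes. So for `IsqType::Q4K` with `DType::F16` the factor is
    // (2 * 256).div_ceil(144) = 4, i.e. the quantized weights are roughly a quarter the size of
    // the f16 originals.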

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because HQQ and AFQ quantize on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}
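
// These impls convert between the GGML-representable ISQ types and `GgmlDType` (non-GGML variants
// such as HQQ/AFQ/F8E4M3 bail, and CUDA builds restrict the accepted set per the check above).
// For example:
//
//     let ggml = GgmlDType::try_from(IsqType::Q4K)?; // GgmlDType::Q4K
//     let isq = IsqType::try_from(GgmlDType::Q4K)?;  // IsqType::Q4K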

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}
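
// A hypothetical serialization round trip through this trait (the `layer`, `device`, and `comm`
// values are assumed to exist; `GgufMatMul` is used here only as an example implementor):
//
//     if layer.isq_serde_supported() {
//         let bytes = layer.serialize()?;
//         let restored = GgufMatMul::deserialize(bytes, &device, &comm, QuantizeOntoGuard::new())?;
//     }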

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On metal, this flushes the command buffer to avoid "A command encoder is already encoding to this command buffer"
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
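
// Usage sketch: hold the drop guard for the duration of the host-side quantization so that only
// one thread quantizes onto the device at a time (on Metal, `acquire` also flushes the command
// buffer, as noted above):
//
//     let guard = QuantizeOntoGuard::new();
//     {
//         let _acquired = guard.acquire(&device);
//         // ...quantize the tensor here, inside the critical section...
//     } // the lock (if real) is released when `_acquired` drops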

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ (in-situ quantization) to this layer, quantizing it to the given dtype on the given device.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}
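
// A small usage sketch against the trait object (here `layer` is assumed to be any
// `Arc<dyn QuantMethod>`, e.g. one returned by `linear_b` below, and `x` an activation tensor):
//
//     let y = layer.forward_autocast(&x)?;           // casts `x` to `quantized_act_type()` if needed
//     let (dtype, device) = layer.dtype_and_device();
//     let w = layer.dequantize_w()?;                 // full-precision copy of the weights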

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
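
// A typical call-site sketch (the `cfg.quantization_config` field, the dimensions, and the
// `vb.pp("gate_proj")` prefix are illustrative; they assume an HF-style model config and a
// `ShardedVarBuilder` scoped to the current layer):
//
//     let gate_proj = linear_b(
//         hidden_size,
//         intermediate_size,
//         /*bias=*/ false,
//         &cfg.quantization_config,
//         vb.pp("gate_proj"),
//     )?;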