mistralrs_quant/
lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
pub mod gemv;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{
    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
#[cfg(feature = "cuda")]
pub use gemv::gemv;
pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_fused;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_interleaved;
pub use utils::isq::apply_immediate_isq;
#[cfg(feature = "cuda")]
pub use utils::softmax_with_sinks;
pub use utils::{fused_glu, GluActivationType};
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}
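
// Illustrative sketch of how the immediate-ISQ state above is typically used.
// The regexes, types, and device below are made up for the example; overrides
// win over the plain predicates, an override with `ty: None` falls back to the
// global type, and `resolve_immediate_isq` is the module-private helper above.
//
//     use regex::Regex;
//
//     set_immediate_isq_with_overrides(
//         Some(IsqType::Q4K),
//         vec![Regex::new(r"mlp\..*\.weight$").unwrap()],
//         vec![ImmediateIsqOverride {
//             predicate: Regex::new(r"^lm_head\.weight$").unwrap(),
//             ty: Some(IsqType::Q8_0),
//             device: Some(Device::Cpu),
//         }],
//     );
//
//     let params = get_immediate_isq().unwrap();
//     assert!(resolve_immediate_isq(&params, "lm_head.weight").is_some());
//     clear_immediate_isq();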

#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
                )))
            },
        }
    }
}
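
// Illustrative sketch of the deserialization behavior implemented above,
// assuming a JSON deserializer such as `serde_json` is available (it is not
// imported by this file). A `quant_method` of "gptq" maps to `GptqAwq`, and a
// config with no `quant_method` at all falls back to `Afq`:
//
//     let gptq: QuantizedConfig = serde_json::from_str(
//         r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#,
//     )
//     .unwrap();
//     assert!(matches!(gptq, QuantizedConfig::GptqAwq { bits: 4, is_awq: false, .. }));
//
//     let implicit: QuantizedConfig =
//         serde_json::from_str(r#"{ "bits": 4, "group_size": 64 }"#).unwrap();
//     assert!(matches!(implicit, QuantizedConfig::Afq { .. }));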

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}
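
// Illustrative sketch of what `pack_factor` is for: a rough size-reduction
// factor versus an unquantized dtype, useful for memory estimates. The shapes
// below are made up; nothing here is asserted elsewhere in this crate.
//
//     let cfg = QuantizedConfig::Afq { bits: 4, group_size: 64 };
//     let factor = cfg.pack_factor(DType::F16);
//     let original_bytes = 4096 * 4096 * DType::F16.size_in_bytes();
//     // quantized size is approximately original size / pack factor
//     let approx_quantized_bytes = original_bytes / factor;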

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

/// Device/configurable intelligent matrix multiplication
/// - Handles the limitation of `accelerate`, which requires f32
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
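
// Illustrative usage sketch for `MatMul` (the tensors are made up and the
// statements assume a function returning `candle_core::Result<()>`). Calling
// through `MatMul` rather than `Tensor::matmul` applies the accelerate/CPU
// dtype workarounds above uniformly.
//
//     let a = Tensor::randn(0f32, 1f32, (4, 8), &Device::Cpu)?.to_dtype(DType::BF16)?;
//     let b = Tensor::randn(0f32, 1f32, (8, 2), &Device::Cpu)?.to_dtype(DType::BF16)?;
//     let c = MatMul.matmul(&a, &b)?;
//     let scaled = MatMul.matmul_affine_div(&a, &b, (8f64).sqrt())?;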

/// Device/configurable intelligent convolution
/// - Handles the limitation of the CPU backend, which requires f32
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because our HQQ quantizes on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}
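
// Illustrative sketch of the conversions above: GGML-backed ISQ types
// round-trip through `GgmlDType`, while the HQQ/AFQ/FP8 variants are rejected
// (and, with the `cuda` feature, so are GGML types unsupported on CUDA).
//
//     let ggml = GgmlDType::try_from(IsqType::Q4K).unwrap();
//     assert_eq!(IsqType::try_from(ggml).unwrap(), IsqType::Q4K);
//     assert!(GgmlDType::try_from(IsqType::HQQ4).is_err());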

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On Metal, this waits for outstanding work to finish to avoid "A command encoder is already encoding to this command buffer" errors.
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // Necessary to avoid "A command encoder is already encoding to this command buffer" errors.
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
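
// Illustrative usage sketch: hold the guard for the duration of a quantization
// critical section. As implemented above, the guard is a no-op on CUDA builds
// and a real mutex elsewhere, with a command-buffer flush on Metal.
//
//     let guard = QuantizeOntoGuard::new();
//     {
//         let _lock = guard.acquire(&Device::Cpu);
//         // ... quantize tensors onto the target device here ...
//     } // `_lock` is dropped, releasing the lock (if any)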

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply in-situ quantization (ISQ) to this layer, quantizing onto the given device.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}
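
// Illustrative sketch of driving a `dyn QuantMethod`. The layer is assumed to
// come from one of the constructors below (`linear`, `linear_no_bias`,
// `linear_b`); `forward_autocast` handles the activation-dtype round trip.
//
//     fn run(layer: &dyn QuantMethod, x: &Tensor) -> Result<Tensor> {
//         // Casts `x` to `layer.quantized_act_type()` if needed, runs the
//         // quantized matmul, and casts the result back to `x`'s dtype.
//         layer.forward_autocast(x)
//     }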

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
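
// Illustrative sketch of constructing a layer with `linear_b`. The dimensions
// and the caller-provided `vb`/`config` are assumptions for the example; in
// practice they come from the model loader.
//
//     fn make_proj(
//         vb: ShardedVarBuilder,
//         config: &Option<QuantizedConfig>,
//     ) -> Result<Arc<dyn QuantMethod>> {
//         // Dispatches on `config` (GPTQ/AWQ, FP8, bitsandbytes, AFQ, MXFP4) or
//         // falls back to an unquantized `Linear`, then applies immediate ISQ.
//         linear_b(4096, 11008, /* bias */ false, config, vb)
//     }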