// mistralrs_quant/lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{
    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_fused;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_interleaved;
pub use utils::isq::apply_immediate_isq;
#[cfg(feature = "cuda")]
pub use utils::softmax_with_sinks;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}
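
// Hedged usage sketch for the immediate-ISQ helpers above: register a target
// quantization type plus regex predicates for the current thread, then resolve a
// concrete tensor prefix. The regex and the prefix string are illustrative only.
#[cfg(test)]
mod immediate_isq_example {
    use super::*;
    use regex::Regex;

    #[test]
    fn resolves_a_matching_prefix() {
        set_immediate_isq(
            Some(IsqType::Q4K),
            vec![Regex::new(r"down_proj\.weight$").unwrap()],
        );
        let params = get_immediate_isq().expect("set on this thread");
        let hit = resolve_immediate_isq(&params, "model.layers.0.mlp.down_proj.weight");
        assert!(matches!(hit, Some(ImmediateIsqMatch { ty: IsqType::Q4K, .. })));
        clear_immediate_isq();
    }
}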

#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => {
                Ok(QuantizedConfig::MXFP4 {})
            }
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
                )))
            },
        }
    }
}
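
// A hedged sketch of how the custom deserializer above maps an HF-style
// `quantization_config` JSON onto `QuantizedConfig`. Assumes `serde_json` is
// available (e.g. as a dev-dependency); the JSON keys mirror `RawConfig`.
#[cfg(test)]
mod quantized_config_example {
    use super::*;

    #[test]
    fn parses_a_gptq_config() {
        let json = r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#;
        let cfg: QuantizedConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(
            cfg,
            QuantizedConfig::GptqAwq {
                bits: 4,
                group_size: 128,
                is_awq: false,
                ..
            }
        ));
    }
}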

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4, // mxfp4 sentinel: two FP4 values per byte, i.e. ~4x smaller than a 16-bit dtype
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}
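
// Worked example (a sketch) of the size arithmetic implied by `pack_factor`:
// `original size / pack factor ≈ quantized size`. The tensor shape is illustrative.
#[cfg(test)]
mod pack_factor_example {
    use super::*;
    use candle_core::DType;

    #[test]
    fn estimates_quantized_size() {
        // A 4096 x 4096 weight stored as f16 occupies 32 MiB.
        let elems: usize = 4096 * 4096;
        let f16_bytes = elems * DType::F16.size_in_bytes();
        // Q4K packs roughly `factor` f16 elements into the space of one.
        let factor = IsqType::Q4K.pack_factor(DType::F16);
        let approx_q4k_bytes = f16_bytes / factor;
        assert!(factor > 1 && approx_q4k_bytes < f16_bytes);
    }
}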

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

/// Device/configurable intelligent matrix multiplication
/// - Handles the limitation of `accelerate`, which requires f32
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result is divided by the `scale` parameter.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result is multiplied by the `scale` parameter.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
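
// Minimal usage sketch for `MatMul`: on CPU (without `accelerate`) the operands are
// temporarily cast to f16 for the gemm and the output is cast back to the input dtype.
#[cfg(test)]
mod matmul_example {
    use super::*;
    use candle_core::{Device, Tensor};

    #[test]
    fn matmul_preserves_shape_and_dtype() -> candle_core::Result<()> {
        let a = Tensor::randn(0f32, 1f32, (2, 8), &Device::Cpu)?;
        let b = Tensor::randn(0f32, 1f32, (8, 4), &Device::Cpu)?;
        let c = MatMul.matmul(&a, &b)?;
        assert_eq!(c.dims(), &[2, 4]);
        assert_eq!(c.dtype(), a.dtype());
        Ok(())
    }
}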

/// Device/configurable intelligent convolution
/// - Handles the limitation of the CPU backend, which requires f32
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}
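
// Usage sketch for `Convolution::forward_1d`: on CPU the weights, bias, and input are
// upcast to f32 for the convolution and the output is cast back afterwards. Shapes are
// illustrative.
#[cfg(test)]
mod convolution_example {
    use super::*;
    use candle_core::{Device, Tensor};
    use candle_nn::{Conv1d, Conv1dConfig};

    #[test]
    fn conv1d_on_cpu() -> candle_core::Result<()> {
        // Conv1d weights are (out_channels, in_channels, kernel_size).
        let weight = Tensor::randn(0f32, 1f32, (4, 3, 3), &Device::Cpu)?;
        let conv = Conv1d::new(weight, None, Conv1dConfig::default());
        // Input is (batch, in_channels, length).
        let x = Tensor::randn(0f32, 1f32, (1, 3, 10), &Device::Cpu)?;
        let y = Convolution.forward_1d(&conv, &x)?;
        assert_eq!(y.dims(), &[1, 4, 8]);
        Ok(())
    }
}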

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because HQQ/AFQ quantization runs on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (non-CUDA builds, e.g. Metal) and Fake (CUDA builds, where no locking is needed)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On metal, this waits for outstanding work to finish to avoid "A command encoder is already encoding to this command buffer"
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
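
// Usage sketch for `QuantizeOntoGuard`: serialize host-side quantization work behind the
// shared mutex. On CUDA builds `acquire` returns a no-op guard; elsewhere it takes the
// lock (and on Metal it first flushes outstanding command buffers).
#[cfg(test)]
mod quantize_onto_guard_example {
    use super::*;
    use candle_core::Device;

    #[test]
    fn guard_scopes_the_critical_section() {
        let guard = QuantizeOntoGuard::new();
        {
            let _acquired = guard.acquire(&Device::Cpu);
            // ... quantize a tensor onto the host here ...
        } // the lock (if any) is released at the end of the scope
        // Clones are cheap and share the same underlying mutex.
        let _shared = guard.clone();
    }
}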

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically casts the input to the required quantized activation type and casts the result back.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically casts the input to the required quantized activation type and casts the result back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ to this layer: requantize the weights to `dtype`, placing the result on `device`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}
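
// Hedged example of driving a layer through the `QuantMethod` interface: wrap a plain
// `Linear` in `UnquantLinear` and call `forward_autocast`, which returns a tensor in
// the caller's dtype. The shapes are illustrative.
#[cfg(test)]
mod quant_method_example {
    use super::*;
    use candle_core::{Device, Tensor};
    use candle_nn::Linear;

    #[test]
    fn unquant_linear_forward() -> candle_core::Result<()> {
        // (out_dim, in_dim) weight, no bias.
        let weight = Tensor::randn(0f32, 1f32, (16, 32), &Device::Cpu)?;
        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, None),
        ))?;
        let x = Tensor::randn(0f32, 1f32, (2, 4, 32), &Device::Cpu)?;
        let y = layer.forward_autocast(&x)?;
        assert_eq!(y.dims(), &[2, 4, 16]);
        assert_eq!(y.dtype(), x.dtype());
        Ok(())
    }
}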

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
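
// Hedged usage sketch for the constructors above (left as a comment since it needs real
// checkpoint files): `vb` would be a `ShardedVarBuilder` obtained from `ShardedSafeTensors`,
// and `quantization_config` an `Option<QuantizedConfig>` parsed from the model config.
// The names (`hidden_size`, `gate_proj`, `x`, ...) are illustrative only.
//
//     let gate_proj = linear_b(
//         hidden_size,
//         intermediate_size,
//         /* bias = */ false,
//         &quantization_config,
//         vb.pp("gate_proj"),
//     )?;
//     let y = gate_proj.forward_autocast(&x)?;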