mistralrs_quant/lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{
    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}

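// A minimal usage sketch of the immediate-ISQ thread-local API above (illustrative only;
// the quant type, regex, and layer prefix are assumptions, not taken from real model code):
//
//     use regex::Regex;
//
//     // Quantize every attention projection to Q4K as it is loaded on this thread.
//     set_immediate_isq(
//         Some(IsqType::Q4K),
//         vec![Regex::new(r"\.(q|k|v|o)_proj\.weight$").unwrap()],
//     );
//
//     // Later, a var builder whose prefix is e.g. "model.layers.0.self_attn.q_proj"
//     // matches (a ".weight" suffix is appended before matching, see `immediate_isq_match`):
//     // should_apply_immediate_isq(&vb) == true
//
//     // When loading is done, reset the thread-local state.
//     clear_immediate_isq();
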
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
                )))
            },
        }
    }
}

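// A minimal sketch of how this custom deserializer is exercised (the JSON is illustrative,
// and `serde_json` is assumed to be available; it is not imported by this file):
//
//     let json = r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#;
//     let cfg: QuantizedConfig = serde_json::from_str(json)?;
//     assert!(matches!(cfg, QuantizedConfig::GptqAwq { bits: 4, is_awq: false, .. }));
//
//     // Omitting `quant_method` falls back to AFQ, provided `bits` and `group_size` exist:
//     let json = r#"{ "bits": 4, "group_size": 64 }"#;
//     let cfg: QuantizedConfig = serde_json::from_str(json)?;
//     assert!(matches!(cfg, QuantizedConfig::Afq { bits: 4, group_size: 64 }));
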
impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

/// Device/configurable intelligent matrix multiplication
/// - Handles the limitation of `accelerate`, which requires f32
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
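
// A minimal usage sketch of `MatMul` (CPU, f32 tensors; the shapes are illustrative and
// not from the original source):
//
//     let a = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
//     let b = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
//     let c = MatMul.matmul(&a, &b)?;                   // (2, 3)
//     let c_scaled = MatMul.matmul_affine_div(&a, &b, 8.0)?; // (2, 3), divided by 8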

/// Device/configurable intelligent convolution
/// - Handles the limitation of the CPU backend, which requires f32
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}
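
// A minimal usage sketch of `Convolution` (illustrative; `conv1d_layer` and `x` are assumed
// to exist, with `x` shaped (batch, channels, length)):
//
//     // On CPU, the weights and input are upcast to f32 for the convolution and the
//     // result is cast back to the original dtype; on other devices the layer runs as-is.
//     let y = Convolution.forward_1d(&conv1d_layer, &x)?;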

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }
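
    // Worked example of the formula above (block/type sizes taken from candle's GGML
    // tables): a `Q4K` block covers 256 values and occupies 144 bytes, so over `f32`
    // the pack factor is (4 * 256).div_ceil(144) = 8, and over `f16`/`bf16` it is
    // (2 * 256).div_ceil(144) = 4.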

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because our HQQ quantizes on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On metal, this waits for outstanding work to finish to avoid "A command encoder is already encoding to this command buffer"
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
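
// A minimal sketch of how the guard is intended to wrap a quantization critical section
// (illustrative; `device` and the quantization body are assumptions):
//
//     let guard = QuantizeOntoGuard::new();
//     {
//         let _lock = guard.acquire(&device);
//         // Quantize the weight here. On Metal/CPU builds the mutex serializes concurrent
//         // quantization; on CUDA builds the guard is a no-op (`Fake`).
//     } // `_lock` dropped, releasing the critical section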

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ to this layer if the quant is backed by a qmatmul, quantizing onto the given device.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}
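
// Shape sketch for `gather_forward` / `gather_forward_autocast` (illustrative numbers;
// the output shape is inferred from the per-expert matmul and is not spelled out above):
//
//     a:       (n_tokens = 8, n_experts = 2, cols = 4096)   // activations per selected expert
//     weights: (n_experts = 64, rows = 11008, cols = 4096)  // all experts held by `self`
//     indices: (n_tokens = 8, n_experts = 2)                // which expert each slot uses
//     output:  (n_tokens = 8, n_experts = 2, rows = 11008)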

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
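
// A minimal sketch of constructing a layer through these helpers (illustrative; the
// dimensions, prefix, and surrounding model-loading context are assumptions):
//
//     // `config` comes from the model's parsed `quantization_config`, if any.
//     let gate_proj = linear_b(4096, 11008, /* bias = */ false, &config, vb.pp("mlp.gate_proj"))?;
//
//     // At runtime, cast activations to the layer's quantized activation dtype (if any),
//     // run the matmul, and cast back:
//     let y = gate_proj.forward_autocast(&x)?;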