// mistralrs_quant/lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

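/// Set the thread-local immediate ISQ settings used while loading a model.
///
/// Illustrative sketch only; the regex shown is a hypothetical predicate, matched
/// against tensor prefixes with `.weight` appended (see `immediate_isq_match`):
/// ```ignore
/// use regex::Regex;
///
/// set_immediate_isq(
///     Some(IsqType::Q4K),
///     vec![Regex::new(r"\.(q|k|v|o)_proj\.weight$").unwrap()],
/// );
/// // ... create layers: each matching layer is quantized as it is built ...
/// clear_immediate_isq();
/// ```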
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}

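/// Model quantization configuration, typically parsed from the `quantization_config`
/// section of a model's `config.json`.
///
/// A hedged, illustrative example of the kind of JSON the custom `Deserialize` impl
/// below accepts (exact checkpoint contents vary; assumes `serde_json`):
/// ```ignore
/// let cfg: QuantizedConfig =
///     serde_json::from_str(r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#)?;
/// assert!(matches!(cfg, QuantizedConfig::GptqAwq { bits: 4, .. }));
/// ```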
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => {
                Ok(QuantizedConfig::MXFP4 {})
            }
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
                )))
            },
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

/// Device/configurable intelligent matrix multiplication
/// - Handles the limitation of `accelerate`, which requires f32
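///
/// Illustrative usage (assumes `a` and `b` are matmul-compatible `Tensor`s):
/// ```ignore
/// let c = MatMul.matmul(&a, &b)?;
/// let scaled = MatMul.matmul_affine_div(&a, &b, 8.0)?;
/// ```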
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

/// Device/configurable intelligent convolution
/// - Handles the limitation of the CPU backend, which requires f32
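///
/// Illustrative usage (assumes `conv` is a `candle_nn::Conv2d` and `x` an NCHW input `Tensor`):
/// ```ignore
/// let y = Convolution.forward_2d(&conv, &x)?;
/// ```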
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
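    ///
    /// Worked example (illustrative): for `Q8_0` from `f16`, a block of 32 `f16` values
    /// is 64 bytes and quantizes to a 34-byte block, so the pack factor is
    /// `64.div_ceil(34) = 2`.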
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because our HQQ quantizes on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On Metal, this flushes the command buffer to avoid "A command encoder is already encoding to this command buffer" errors.
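    ///
    /// Illustrative sketch (assumes `device` is the `Device` being quantized onto):
    /// ```ignore
    /// let guard = QuantizeOntoGuard::new();
    /// let _lock = guard.acquire(&device);
    /// // ... quantize while the lock (real on Metal/CPU, a no-op on CUDA) is held ...
    /// ```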
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // Necessary to avoid "A command encoder is already encoding to this command buffer" errors
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to the required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to the required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ (in-situ quantization) of the given type to this layer, quantizing onto `device`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

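/// Convenience wrapper over [`linear`] / [`linear_no_bias`], selected by `bias`.
///
/// Illustrative sketch (the names `cfg`, `hidden_size`, and the `vb` prefix are
/// hypothetical; `cfg.quantization_config` is assumed to be an `Option<QuantizedConfig>`):
/// ```ignore
/// let vb = vb.pp("model.layers.0.self_attn.q_proj");
/// let q_proj = linear_b(hidden_size, hidden_size, false, &cfg.quantization_config, vb)?;
/// let y = q_proj.forward_autocast(&x)?;
/// ```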
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}