mistralrs_quant/
lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
pub mod rotary;
pub mod safetensors;
mod unquantized;
mod utils;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    linear_no_bias_static_lora, LoraAdapter, LoraConfig, StaticLoraConfig, APPLIED_LORAS,
    MULTI_LORA_DELIMITER,
};
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

static IMMEDIATE_ISQ: Mutex<Option<ImmediateIsqParams>> = Mutex::new(None);

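/// Enable "immediate ISQ": any layer whose dotted weight path matches one of
/// `predicates` is quantized to `isq` as it is loaded, instead of in a later pass.
///
/// A minimal sketch (the regex below is illustrative, not one shipped with this crate):
///
/// ```ignore
/// use regex::Regex;
/// use mistralrs_quant::{set_immediate_isq, IsqType};
///
/// // Quantize every matching weight to Q4K while the model loads.
/// set_immediate_isq(
///     Some(IsqType::Q4K),
///     vec![Regex::new(r"\.weight$").unwrap()],
/// );
/// ```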
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    let mut guard = IMMEDIATE_ISQ.lock().expect("IMMEDIATE_ISQ mutex poisoned");
    *guard = Some(ImmediateIsqParams {
        guard: QuantizeOntoGuard::new(),
        ty: isq,
        predicates,
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    IMMEDIATE_ISQ.lock().ok().and_then(|guard| guard.clone())
}

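/// Returns `true` if immediate ISQ is configured and the prefix of `vb` (with `.weight`
/// appended, to mirror how the ISQ regexes are written) matches any configured predicate.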
pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}

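/// Pre-quantization configuration (e.g. a model's `quantization_config`). The variant
/// is selected from the `quant_method` field; see the custom `Deserialize` impl below.
///
/// For example, a GPTQ-style config of this shape deserializes to
/// [`QuantizedConfig::GptqAwq`] (values shown are illustrative):
///
/// ```json
/// { "quant_method": "gptq", "bits": 4, "group_size": 128 }
/// ```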
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified",
                    unknown_method
                )))
            },
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
}

/// Device-aware, configurable matrix multiplication helper.
/// - Works around the `accelerate` limitation that matmul operands must be f32.
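///
/// A minimal usage sketch (tensor shapes are illustrative):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use mistralrs_quant::MatMul;
///
/// let a = Tensor::zeros((2, 3), DType::F16, &Device::Cpu)?;
/// let b = Tensor::zeros((3, 4), DType::F16, &Device::Cpu)?;
/// let c = MatMul.matmul(&a, &b)?; // (2, 4)
/// ```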
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
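    ///
    /// For example, with `dtype = F16` and Q4K (GGML super-blocks of 256 weights stored
    /// in 144 bytes), this evaluates to `ceil(2 * 256 / 144) = 4`, i.e. roughly a 4x
    /// size reduction. (Block and type sizes come from `GgmlDType`; the numbers here are
    /// for illustration only.)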
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because our HQQ quantizes on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On Metal, this flushes the command buffer to avoid the error "A command encoder is already encoding to this command buffer".
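    ///
    /// A minimal sketch of the intended usage (the surrounding quantization code is
    /// omitted):
    ///
    /// ```ignore
    /// let guard = QuantizeOntoGuard::new();
    /// {
    ///     // Held for the duration of the critical section; released on drop.
    ///     let _acquired = guard.acquire(&device);
    ///     // ... quantize onto the host device here ...
    /// }
    /// ```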
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
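///
/// A minimal sketch of how a layer is typically constructed and used (an unquantized
/// layer is shown purely for illustration; `weight` and `x` are assumed tensors):
///
/// ```ignore
/// use candle_nn::Linear;
/// use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear};
///
/// let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
///     Linear::new(weight, None),
/// ))?;
/// // Casts the activation to `quantized_act_type()` (if any) and back around the matmul.
/// let y = layer.forward_autocast(&x)?;
/// ```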
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically casts to the required quantization activation type and back.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically casts to the required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ to this layer, quantizing the weights to `dtype` on `device`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

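/// Build a (possibly quantized) linear layer without a bias from `vb`, honoring `config`
/// if present. If immediate ISQ is configured for this prefix, the weight is loaded on
/// the CPU and then quantized via `apply_immediate_isq`.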
pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

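/// Build a (possibly quantized) linear layer with a bias from `vb`, honoring `config`
/// if present. Falls back to a `DummyLayer` when the `weight` and `bias` tensors are
/// not both present.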
pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

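/// Build a (possibly quantized) linear layer, with or without a bias, from `vb`.
/// Dispatches to [`linear`] or [`linear_no_bias`].
///
/// A minimal sketch (the prefix and dimensions are illustrative):
///
/// ```ignore
/// use mistralrs_quant::linear_b;
///
/// // `vb` is a ShardedVarBuilder positioned at the layer prefix; `config` comes from the
/// // model's quantization config, or is `None` for unquantized weights.
/// let proj = linear_b(4096, 11008, /* bias */ false, &config, vb.pp("gate_proj"))?;
/// let y = proj.forward_autocast(&x)?;
/// ```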
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}