mistralrs_quant/
lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
pub mod rotary;
pub mod safetensors;
mod unquantized;
mod utils;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

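/// Configure immediate (per-thread) ISQ for subsequently constructed layers: a layer whose
/// `"{prefix}.weight"` name matches one of `predicates` is quantized to `isq` as it is loaded
/// (see `should_apply_immediate_isq`).
///
/// A minimal usage sketch; the layer-name pattern below is hypothetical:
///
/// ```ignore
/// use mistralrs_quant::{clear_immediate_isq, get_immediate_isq, set_immediate_isq, IsqType};
/// use regex::Regex;
///
/// let predicates = vec![Regex::new(r"model\.layers\.\d+\.mlp\..*\.weight").unwrap()];
/// set_immediate_isq(Some(IsqType::Q4K), predicates);
/// assert!(get_immediate_isq().is_some());
/// // ... construct layers with `linear_b`/`linear_no_bias`, which consult this state ...
/// clear_immediate_isq();
/// ```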
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    // Add a .weight to match the ISQ regexes!
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}

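/// Quantization configuration, typically parsed from a model's `config.json`
/// (`quantization_config`). The variant is selected by the `quant_method` field; when
/// `quant_method` is absent, the config is parsed as AFQ.
///
/// A minimal deserialization sketch, assuming `serde_json` is available:
///
/// ```ignore
/// use mistralrs_quant::QuantizedConfig;
///
/// let config: QuantizedConfig = serde_json::from_str(
///     r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#,
/// )
/// .unwrap();
/// assert_eq!(config.name(), "gptq");
/// ```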
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
                )))
            },
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
}

/// Device- and configuration-aware matrix multiplication helper.
/// - Handles the `accelerate` limitation of requiring f32 matmul inputs.
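///
/// A minimal usage sketch:
///
/// ```ignore
/// use candle_core::{DType, Device, Result, Tensor};
/// use mistralrs_quant::MatMul;
///
/// fn demo() -> Result<()> {
///     let a = Tensor::zeros((4, 8), DType::F32, &Device::Cpu)?;
///     let b = Tensor::zeros((8, 16), DType::F32, &Device::Cpu)?;
///     let _c = MatMul.matmul(&a, &b)?; // shape (4, 16)
///     let _scaled = MatMul.matmul_affine_mul(&a, &b, 0.125)?; // (a @ b) * 0.125
///     Ok(())
/// }
/// ```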
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
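    ///
    /// A rough worked example: with `DType::F16` (2 bytes per element) and `Q8_0`
    /// (blocks of 32 elements stored in 34 bytes), the factor is `ceil(2 * 32 / 34) = 2`,
    /// i.e. the quantized weight is roughly half the size of the f16 original.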
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because these types quantize on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On Metal, this flushes the command buffer to avoid "A command encoder is already encoding to this command buffer" errors.
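    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// use candle_core::Device;
    /// use mistralrs_quant::QuantizeOntoGuard;
    ///
    /// let guard = QuantizeOntoGuard::new();
    /// {
    ///     let _lock = guard.acquire(&Device::Cpu);
    ///     // ... quantize a tensor while the lock (real on Metal/CPU, a no-op on CUDA) is held ...
    /// } // released here
    /// ```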
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
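///
/// A minimal usage sketch of the forward path; `layer` can be any `Arc<dyn QuantMethod>`
/// returned by e.g. `linear_b`:
///
/// ```ignore
/// use candle_core::{Result, Tensor};
/// use mistralrs_quant::QuantMethod;
///
/// fn project(layer: &dyn QuantMethod, x: &Tensor) -> Result<Tensor> {
///     // Casts `x` to the layer's quantized activation dtype (if any), runs the matmul,
///     // and casts the result back to the original dtype of `x`.
///     layer.forward_autocast(x)
/// }
/// ```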
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ (in-situ quantization) to this layer, quantizing onto the given device.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

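/// Construct a (possibly quantized) linear layer, dispatching on `bias` to `linear` or
/// `linear_no_bias`.
///
/// A minimal usage sketch; `hidden_size`, `quant_config`, `vb` (a `ShardedVarBuilder` already
/// rooted at this layer's prefix), and `xs` are placeholders for values provided by the
/// surrounding model code:
///
/// ```ignore
/// use mistralrs_quant::linear_b;
///
/// let q_proj = linear_b(hidden_size, hidden_size, /* bias */ false, &quant_config, vb)?;
/// let q = q_proj.forward_autocast(&xs)?;
/// ```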
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}