// mistralrs_quant/lib.rs

use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
pub mod rotary;
pub mod safetensors;
mod unquantized;
mod utils;

use gptq::gptq_linear;
use lora::merge_lora_weights;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, ReplicatedLayer,
        RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    linear_no_bias_static_lora, LoraAdapter, LoraConfig, StaticLoraConfig, APPLIED_LORAS,
    MULTI_LORA_DELIMITER,
};
pub use unquantized::UnquantLinear;
pub use utils::UQFF_QUANT_TYPE_OFFSET;

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

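/// Quantization settings parsed from a model's `quantization_config` in `config.json`.
///
/// A minimal deserialization sketch (assumes `serde_json` is available; the JSON
/// fragment is hypothetical):
///
/// ```ignore
/// let cfg: QuantizedConfig =
///     serde_json::from_str(r#"{"quant_method": "gptq", "bits": 4, "group_size": 128}"#)?;
/// assert_eq!(cfg.name(), "gptq");
/// ```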
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    Gptq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
}

// Common fields for all variants
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer implementation
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Gptq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified",
                    unknown_method
                )))
            },
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::Gptq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::Gptq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
        }
    }
}

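/// Construction parameters passed to [`QuantMethod::new`].
///
/// A minimal sketch: wrap a plain weight in an unquantized layer (`in_dim`,
/// `out_dim`, and the device below are illustrative):
///
/// ```ignore
/// let w = Tensor::zeros((out_dim, in_dim), DType::F32, &Device::Cpu)?;
/// let layer =
///     <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(Linear::new(w, None)))?;
/// ```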
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    Gptq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        gptq_qzeros: Option<Tensor>,
        gptq_scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
}

/// Device- and configuration-aware matrix multiplication.
/// - Handles the limitation of `accelerate`, which requires f32
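///
/// A minimal usage sketch (shapes are illustrative):
///
/// ```ignore
/// let a = Tensor::randn(0f32, 1f32, (4, 8), &Device::Cpu)?;
/// let b = Tensor::randn(0f32, 1f32, (8, 2), &Device::Cpu)?;
/// let c = MatMul.matmul(&a, &b)?; // shape (4, 2)
/// ```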
pub struct MatMul;

impl MatMul {
    /// Compute matrix-matrix product.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute matrix-matrix product.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? / scale
    }

    /// Compute matrix-matrix product.
    /// The result will be multiplied by the `scale` parameter in an affine multiplication.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
        self.matmul(a, b)? * scale
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute quantized matrix-matrix product.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Factor by which the weight size is reduced over the given dtype.
    /// original size / pack factor = quantized size
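    ///
    /// A minimal sketch: estimate how many bytes a quantized f16 weight would take
    /// (`original_bytes` is hypothetical):
    ///
    /// ```ignore
    /// let factor = IsqType::Q4K.pack_factor(DType::F16);
    /// let quantized_bytes = original_bytes / factor;
    /// ```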
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => {
                (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size()) / GgmlDType::Q4_0.type_size()
            }
            Self::Q4_1 => {
                (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size()) / GgmlDType::Q4_1.type_size()
            }
            Self::Q5_0 => {
                (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size()) / GgmlDType::Q5_0.type_size()
            }
            Self::Q5_1 => {
                (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size()) / GgmlDType::Q5_1.type_size()
            }
            Self::Q8_0 | Self::AFQ8 => {
                (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size()) / GgmlDType::Q8_0.type_size()
            }
            Self::Q8_1 => {
                (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size()) / GgmlDType::Q8_1.type_size()
            }
            Self::Q2K | Self::AFQ2 => {
                (dtype.size_in_bytes() * GgmlDType::Q2K.block_size()) / GgmlDType::Q2K.type_size()
            }
            Self::Q3K | Self::AFQ3 => {
                (dtype.size_in_bytes() * GgmlDType::Q3K.block_size()) / GgmlDType::Q3K.type_size()
            }
            Self::Q4K => {
                (dtype.size_in_bytes() * GgmlDType::Q4K.block_size()) / GgmlDType::Q4K.type_size()
            }
            Self::Q5K => {
                (dtype.size_in_bytes() * GgmlDType::Q5K.block_size()) / GgmlDType::Q5K.type_size()
            }
            Self::Q6K | Self::AFQ6 => {
                (dtype.size_in_bytes() * GgmlDType::Q6K.block_size()) / GgmlDType::Q6K.type_size()
            }
            Self::Q8K => {
                (dtype.size_in_bytes() * GgmlDType::Q8K.block_size()) / GgmlDType::Q8K.type_size()
            }
            // Estimates
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // Use 1 because our HQQ quantizes on the GPU
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

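/// Serialization/deserialization support for quantized layers (used for UQFF).
///
/// A minimal write-side sketch (assumes `layer: &dyn QuantizedSerde`; the output
/// path is hypothetical):
///
/// ```ignore
/// if layer.isq_serde_supported() {
///     let bytes = layer.serialize()?;
///     std::fs::write("layer.uqff", &bytes)?;
/// }
/// ```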
pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Used to gate access to quantizing onto the host device
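///
/// A minimal sketch: hold the lock while quantizing a layer onto the host device
/// (with the `cuda` feature, `acquire` returns a no-op guard instead).
///
/// ```ignore
/// let guard = QuantizeOntoGuard::new();
/// let _lock = guard.acquire();
/// // ... quantize onto the host device while the lock is held ...
/// ```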
#[derive(Clone)]
#[allow(unused)]
pub struct QuantizeOntoGuard(Arc<Mutex<()>>);

/// Real (for Metal) and Fake (for CUDA)
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(())))
    }

    pub fn acquire(&self) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            QuantizeOntoDropGuard::Real(self.0.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Quantized method for a quantized matmul.
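///
/// A minimal usage sketch (assumes `layer: Arc<dyn QuantMethod>` built with e.g.
/// [`linear_b`], and an activation tensor `x`):
///
/// ```ignore
/// let y = layer.forward_autocast(&x)?;
/// ```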
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ (in-situ quantization), re-quantizing this layer to `dtype` on `device`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::Gptq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    Ok(layer)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::Gptq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors)
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    Ok(layer)
}

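/// Convenience constructor that dispatches to [`linear`] or [`linear_no_bias`]
/// depending on `bias`.
///
/// A minimal usage sketch (assumes a `ShardedVarBuilder` scoped to the layer;
/// the `"gate_proj"` prefix below is hypothetical):
///
/// ```ignore
/// let proj = linear_b(
///     in_dim,
///     out_dim,
///     /* bias = */ false,
///     &config.quantization_config,
///     vb.pp("gate_proj"),
/// )?;
/// let y = proj.forward_autocast(&x)?;
/// ```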
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}