use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
pub mod rotary;
pub mod safetensors;
mod unquantized;
mod utils;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    linear_no_bias_static_lora, LoraAdapter, LoraConfig, StaticLoraConfig, APPLIED_LORAS,
    MULTI_LORA_DELIMITER,
};
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

static IMMEDIATE_ISQ: Mutex<Option<ImmediateIsqParams>> = Mutex::new(None);

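/// Enable "immediate ISQ": linear layers built after this call whose weight path matches one of
/// `predicates` are quantized to `isq` as they are loaded (see [`should_apply_immediate_isq`] and
/// [`apply_immediate_isq`]). Predicates are matched against `"{prefix}.weight"`.
///
/// A minimal sketch; the predicate is illustrative only, real weight paths depend on the model:
///
/// ```ignore
/// use regex::Regex;
///
/// // Quantize any layer whose weight path ends in `proj.weight` to Q4K.
/// set_immediate_isq(Some(IsqType::Q4K), vec![Regex::new(r"proj\.weight$").unwrap()]);
/// ```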
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    let mut guard = IMMEDIATE_ISQ.lock().expect("IMMEDIATE_ISQ mutex poisoned");
    *guard = Some(ImmediateIsqParams {
        guard: QuantizeOntoGuard::new(),
        ty: isq,
        predicates,
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    IMMEDIATE_ISQ.lock().ok().and_then(|guard| guard.clone())
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}

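/// Pre-quantization configuration, typically parsed from a model's `quantization_config` JSON
/// (e.g. `{"quant_method": "gptq", "bits": 4, "group_size": 128}`; the values are illustrative).
/// The manual [`Deserialize`] impl below dispatches on `quant_method`, and a config with no
/// `quant_method` at all is interpreted as AFQ.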
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
}

#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, or not specified"
            ))),
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
}

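/// Small wrapper around matrix multiplication that chooses a compute dtype per build: with the
/// `accelerate` feature operands are promoted to F32, on plain CPU they are cast to F16, and on
/// other devices they are multiplied in their original dtype; the result is always returned in
/// the input dtype. A hedged usage sketch (shapes are arbitrary examples):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let a = Tensor::randn(0f32, 1f32, (4, 8), &Device::Cpu)?;
/// let b = Tensor::randn(0f32, 1f32, (8, 2), &Device::Cpu)?;
/// let c = MatMul.matmul(&a, &b)?; // shape (4, 2)
/// ```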
pub struct MatMul;

impl MatMul {
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
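    /// Approximate pack factor relative to `dtype`: for the matching GGML block format this is
    /// `(dtype.size_in_bytes() * block_size).div_ceil(type_size)`; HQQ4, HQQ8 and F8E4M3 use
    /// fixed factors.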
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => Some(1.try_into().unwrap()),
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

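/// Serialization hooks for quantized layers. Implementations that can round-trip themselves
/// report it via `isq_serde_supported`; every provided method bails by default.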
pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

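/// Shared lock used to serialize quantization work: [`QuantizeOntoGuard::acquire`] hands out a
/// RAII [`QuantizeOntoDropGuard`] backed by an `Arc<Mutex<()>>`, so cloned guards contend on the
/// same underlying mutex.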
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

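    /// Acquire the lock for a quantization step on `device`. On CUDA builds this is a no-op and
    /// a `Fake` guard is returned; on Metal the command buffer is flushed before the real lock
    /// is taken.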
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

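/// Common interface implemented by every (possibly quantized) linear layer in this crate:
/// construction from a [`QuantMethodConfig`], dequantization, forward passes (the `*_autocast`
/// variants cast activations to [`QuantMethod::quantized_act_type`] and back), applying a weight
/// delta, in-situ requantization via `apply_isq`, and optional statistics tracking.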
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    fn quantized_act_type(&self) -> Option<DType>;

    fn dtype_and_device(&self) -> (DType, Device);

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else if !vb.contains_tensor("weight") {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, None),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
        let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, Some(bias)),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

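/// Convenience constructor that dispatches to [`linear`] or [`linear_no_bias`] based on `bias`.
///
/// A minimal sketch of how these builders are called from model code; the dimensions are
/// illustrative and `vb` is assumed to already point at the layer's prefix:
///
/// ```ignore
/// fn build_proj(
///     cfg: &Option<QuantizedConfig>,
///     vb: ShardedVarBuilder,
/// ) -> Result<Arc<dyn QuantMethod>> {
///     // 4096 -> 11008 projection without a bias. The quantization scheme (GPTQ/AWQ, FP8,
///     // bitsandbytes, AFQ, or none) is chosen from `cfg`, and immediate ISQ is applied last.
///     linear_b(4096, 11008, false, cfg, vb)
/// }
/// ```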
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}