use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
pub mod rotary;
pub mod safetensors;
mod unquantized;
mod utils;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

/// Parameters controlling immediate in-situ quantization (ISQ) of weights as
/// they are loaded.
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

/// Enable immediate ISQ on the current thread for tensors matching `predicates`.
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
        });
    });
}

/// Get the immediate ISQ parameters, if any, for the current thread.
pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

/// Disable immediate ISQ on the current thread.
pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

/// Returns true if immediate ISQ is enabled on this thread and some predicate
/// matches this builder's `.weight` tensor path.
pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    // Append `.weight` so the prefix matches the ISQ predicate regexes.
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}

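// Example (sketch): enabling immediate ISQ before loading weights. The
// setting is thread-local, so it must be set on the thread that builds the
// layers; the regex below is illustrative.
#[allow(dead_code)]
fn _immediate_isq_example() {
    let predicates = vec![Regex::new(r"proj\.weight$").unwrap()];
    set_immediate_isq(Some(IsqType::Q4K), predicates);
    assert!(get_immediate_isq().is_some());
    // ... construct layers (e.g. via `linear_b`) on this thread ...
    clear_immediate_isq();
}
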
/// Quantization configuration as parsed from a model's `config.json`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
}

// Untagged raw form; `quant_method` decides which variant to build.
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            // No explicit method: assume AFQ if `bits` and `group_size` are present.
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => {
                Err(serde::de::Error::custom(format!(
                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, or not specified"
                )))
            },
        }
    }
}

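// Sketch of what the custom deserializer accepts. Using `serde_json` here is
// an assumption made for illustration (it would need to be a dev-dependency);
// any serde `Deserializer` behaves the same.
#[cfg(test)]
mod quantized_config_example {
    use super::*;

    #[test]
    fn parses_gptq() {
        let cfg: QuantizedConfig = serde_json::from_str(
            r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#,
        )
        .unwrap();
        assert!(matches!(
            cfg,
            QuantizedConfig::GptqAwq {
                bits: 4,
                is_awq: false,
                ..
            }
        ));
    }
}
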
impl QuantizedConfig {
    /// Short name of the quantization method.
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
        }
    }

    /// Human-readable bit width, e.g. "4 bits".
    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
        }
    }

    /// Approximate compression ratio versus an unquantized `dtype` weight,
    /// estimated via the closest GGML quantization type.
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes { .. } => IsqType::Q4K.pack_factor(dtype),
        }
    }
}

/// Constructor arguments for each `QuantMethod` backend.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
}

/// Matrix multiplication helper that picks a dtype strategy per backend.
pub struct MatMul;

impl MatMul {
    /// Compute the matrix-matrix product. On CPU this casts to F32 (with the
    /// `accelerate` feature) or F16 (otherwise) to use faster GEMM kernels,
    /// then casts back to the original dtype.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute the matrix-matrix product, then divide by `scale`.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    /// Compute the matrix-matrix product, then multiply by `scale`.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    /// Compute a quantized matmul via `QMatMul`.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute a quantized matmul via any `QuantMethod`.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

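// Example (sketch): `MatMul` is a drop-in replacement for `Tensor::matmul`
// with the dtype handling above.
#[cfg(test)]
mod matmul_example {
    use super::*;

    #[test]
    fn matmul_shapes() -> Result<()> {
        let dev = Device::Cpu;
        let a = Tensor::randn(0f32, 1f32, (2, 3), &dev)?;
        let b = Tensor::randn(0f32, 1f32, (3, 4), &dev)?;
        let c = MatMul.matmul(&a, &b)?;
        assert_eq!(c.dims(), &[2, 4]);
        Ok(())
    }
}
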
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Compression ratio of an unquantized `dtype` weight over its quantized
    /// encoding, rounded up. AFQ types reuse the nearest GGML block layout as
    /// an estimate.
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // HQQ and FP8 ratios are fixed estimates (assuming a 16-bit source dtype).
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    /// Maximum number of CPU threads to use while quantizing with this type,
    /// or `None` for no limit.
    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // These quantize on-device, so extra CPU threads do not help.
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

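// Worked example: for Q4K at F16, a block of 256 elements (512 bytes
// unquantized) is stored in a 144-byte GGML block, so the pack factor is
// (2 * 256).div_ceil(144) == 4. The exact block/type sizes come from
// `candle_core`; the numbers here are for illustration.
#[cfg(test)]
mod pack_factor_example {
    use super::*;

    #[test]
    fn q4k_f16() {
        assert_eq!(IsqType::Q4K.pack_factor(DType::F16), 4);
    }
}
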
impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

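// Round-trip sketch: GGML-representable ISQ types convert both ways through
// the `TryFrom` impls above.
#[cfg(test)]
mod isq_ggml_roundtrip_example {
    use super::*;

    #[test]
    fn q4k_roundtrip() -> Result<()> {
        let ggml = GgmlDType::try_from(IsqType::Q4K)?;
        assert_eq!(IsqType::try_from(ggml)?, IsqType::Q4K);
        Ok(())
    }
}
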
/// Discriminant identifying the serialized format of a quantized layer.
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

/// Serialization/deserialization support for quantized layers.
pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Guard serializing quantization work so layers quantize one at a time,
/// avoiding memory spikes from concurrent quantization.
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Drop guard returned by `QuantizeOntoGuard::acquire`: a real mutex guard on
/// non-CUDA devices, and a no-op on CUDA.
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the lock for the duration of a quantization operation. On
    /// Metal, the command buffer is flushed first; on CUDA, no lock is taken.
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

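// Example (sketch): holding the guard across one layer's quantization. The
// guard is cheap to clone; all clones share the same mutex.
#[allow(dead_code)]
fn _quantize_onto_example(device: &Device) {
    let guard = QuantizeOntoGuard::new();
    {
        let _held = guard.acquire(device);
        // ... quantize one layer's weights here ...
    } // The lock (if real) is released here.
}
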
pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// A quantized (or unquantized) weight backend behind a common matmul
/// interface.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    /// Construct this backend from its configuration.
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    /// Dequantize the weight to a plain tensor.
    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute the matmul of `self` and `a`, casting `a` to the quantized
    /// activation type (if any) and casting the result back.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute the matmul of `self` and `a`; `self` holds the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Like `forward_autocast`, but gathers weights by `indices` first
    /// (for expert layers).
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If quantized, the dtype activations should be cast to before `forward`.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device.
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight (e.g. a merged LoRA update) to the weights.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Re-quantize this layer to `dtype` on `device`, incrementing
    /// `n_quantized` and using `imatrix_weight` if provided.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    /// If available, return the unquantized weight and bias.
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking imatrix statistics for this layer.
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// Stop tracking stats and return the collected imatrix data.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

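// Example (sketch): all backends are used uniformly through `QuantMethod`.
// `UnquantLinear` wraps a plain `candle_nn::Linear`; the shape handling shown
// here assumes it follows `Linear`'s (out_dim, in_dim) weight convention.
#[cfg(test)]
mod quant_method_example {
    use super::*;

    #[test]
    fn unquant_forward() -> Result<()> {
        let dev = Device::Cpu;
        // Weight shape is (out_dim, in_dim), as in `candle_nn::Linear`.
        let w = Tensor::randn(0f32, 1f32, (4, 3), &dev)?;
        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(w, None),
        ))?;
        let x = Tensor::randn(0f32, 1f32, (2, 3), &dev)?;
        let y = layer.forward_autocast(&x)?;
        assert_eq!(y.dims(), &[2, 4]);
        Ok(())
    }
}
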
impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    // When immediate ISQ applies, load onto the CPU first; quantization then
    // moves the weight to the target device.
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else if !vb.contains_tensor("weight") {
        // The tensor is absent from this checkpoint: use a dummy layer.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, None),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    // When immediate ISQ applies, load onto the CPU first; quantization then
    // moves the weight to the target device.
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
        // The tensors are absent from this checkpoint: use a dummy layer.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
        let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, Some(bias)),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

/// Construct a (possibly quantized) linear layer, with or without bias.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
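
// Example (sketch): wiring a projection through `linear_b`. Obtaining a
// `ShardedVarBuilder` from checkpoint files is elided; the dimensions and the
// `mlp.gate_proj` prefix are illustrative.
#[allow(dead_code)]
fn _linear_example(
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Arc<dyn QuantMethod>> {
    // bias = false routes to `linear_no_bias`; the quantized path is chosen
    // from `config`.
    linear_b(4096, 11008, false, config, vb.pp("mlp.gate_proj"))
}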