use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use blockwise_fp8::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

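/// Install immediate ISQ parameters for the current engine thread: layers whose
/// `"{prefix}.weight"` name matches one of `predicates` are quantized to `isq` as they
/// are loaded (see [`should_apply_immediate_isq`] and [`apply_immediate_isq`]).
///
/// Illustrative sketch; the quant type and regex are arbitrary examples:
///
/// ```ignore
/// let predicates = vec![Regex::new(r"model\.layers\.\d+\.mlp\..*\.weight$").unwrap()];
/// set_immediate_isq(Some(IsqType::Q4K), predicates);
/// assert!(get_immediate_isq().is_some());
/// // ... construct layers on this thread ...
/// clear_immediate_isq();
/// ```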
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
        });
    });
}

/// Retrieve the immediate ISQ parameters, if any, for the current engine thread.
pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

/// Clear any immediate ISQ parameters for the current engine thread.
pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

/// Returns `true` if immediate ISQ is configured with a quant type and one of its
/// predicates matches this builder's `"{prefix}.weight"` tensor name.
pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    // Append `.weight` so the prefix matches the ISQ regexes, which target tensor names.
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}

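/// Parsed quantization configuration of a model (e.g. the `quantization_config` section
/// of a Hugging Face `config.json`).
///
/// Illustrative fragments it deserializes from (a sketch, assuming `serde_json` is
/// available):
///
/// ```ignore
/// let gptq: QuantizedConfig =
///     serde_json::from_str(r#"{"quant_method":"gptq","bits":4,"group_size":128}"#)?;
/// let fp8: QuantizedConfig =
///     serde_json::from_str(r#"{"quant_method":"fp8","weight_block_size":[128,128]}"#)?;
/// // With no `quant_method` at all, `bits` + `group_size` are interpreted as AFQ.
/// ```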
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Untagged superset of the fields; the tagged `QuantizedConfig` is derived from these below.
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            // No `quant_method` given: interpret `bits` + `group_size` as AFQ.
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
            ))),
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

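    /// Approximate compression factor of the quantized weights relative to `dtype`.
    ///
    /// Worked example (GGML K-quant layout): 4-bit GPTQ/AWQ/AFQ routes to
    /// `IsqType::Q4K`, whose 256-element blocks occupy 144 bytes, so at `DType::F32`
    /// the factor is `(4 * 256).div_ceil(144) == 8`.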
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

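/// Centralized matmul dispatch: with the `accelerate` feature the product is computed in
/// f32; on CPU without it, in f16 (to use the faster GEMM kernels); otherwise directly.
/// The result is always cast back to the input dtype.
///
/// Illustrative usage (a sketch):
///
/// ```ignore
/// let a = Tensor::randn(0f32, 1f32, (4, 8), &Device::Cpu)?;
/// let b = Tensor::randn(0f32, 1f32, (8, 2), &Device::Cpu)?;
/// let c = MatMul.matmul(&a, &b)?; // (4, 2), same dtype as `a`
/// ```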
pub struct MatMul;

impl MatMul {
    /// Compute the matrix-matrix product, casting through f32 with `accelerate` or
    /// through f16 on CPU otherwise, then casting the result back to the original dtype.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute the matrix-matrix product, then divide by `scale`.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    /// Compute the matrix-matrix product, then multiply by `scale`.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    /// Compute a quantized matmul via [`QMatMul`].
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute a quantized matmul via a [`QuantMethod`] layer.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Compression factor relative to weights stored as `dtype`, derived from the GGML
    /// block layout (bytes per element at `dtype` over bytes per element quantized).
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    /// Maximum number of CPU threads to use while quantizing with this type, or `None`
    /// for no cap.
    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // These quantize on-device, so a single host thread suffices.
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

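// Illustrative round-trip between the ISQ and GGML type systems (sketch):
//
//     let ggml = GgmlDType::try_from(IsqType::Q4K)?; // GgmlDType::Q4K
//     let back = IsqType::try_from(ggml)?;
//     assert_eq!(back, IsqType::Q4K);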
impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

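// Illustrative: recover the tag stored in a serialized artifact header (sketch):
//
//     assert!(matches!(QuantizedSerdeType::try_from(2usize)?, QuantizedSerdeType::Hqq));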
impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

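/// Gates access to quantizing tensors onto a device, so only one quantization runs in
/// the critical section at a time.
///
/// Illustrative usage (a sketch):
///
/// ```ignore
/// let guard = QuantizeOntoGuard::new();
/// {
///     let _lock = guard.acquire(&device);
///     // ... quantize one tensor onto the device ...
/// } // lock released here (a no-op under the `cuda` feature)
/// ```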
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// `Real` holds the mutex guard; `Fake` is used under the `cuda` feature, where no
/// locking is performed.
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the lock guarding the quantization critical section.
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            // On Metal, flush the command buffer before entering the critical section.
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

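/// A quantized (or unquantized) linear-layer backend: GGUF, GPTQ/AWQ, HQQ, FP8, AFQ,
/// bitsandbytes, MXFP4, or a plain [`Linear`].
///
/// Typical call path (a sketch): construct with [`QuantMethod::new`], then run
/// [`QuantMethod::forward_autocast`], which casts activations to
/// [`QuantMethod::quantized_act_type`] (if any) and casts the output back:
///
/// ```ignore
/// let y = layer.forward_autocast(&x)?;
/// ```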
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    /// Dequantize the weight back to a full-precision tensor.
    fn dequantize_w(&self) -> Result<Tensor>;

    /// Run `forward`, first casting the input to `quantized_act_type()` (if any) and
    /// casting the output back to the input's original dtype.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Forward pass of the layer.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Run `gather_forward` with the same autocasting as [`Self::forward_autocast`].
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Indexed forward pass, for layers that support gathering by `indices`.
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If `Some`, the dtype activations should be cast to before `forward`.
    fn quantized_act_type(&self) -> Option<DType>;

    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight to this layer, returning the updated layer.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply in-situ quantization (ISQ) to this layer onto `device`, incrementing
    /// `n_quantized` and serializing the critical section via `guard`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin collecting activation statistics (e.g. for an imatrix).
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// Finish collecting activation statistics, returning the accumulated tensor.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    // When immediate ISQ applies, load the weight on CPU; `apply_immediate_isq` below
    // handles the quantization.
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else if !vb.contains_tensor("weight") {
        // No weight tensor present: this layer is a dummy placeholder.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, None),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
        // Missing weight or bias tensor: this layer is a dummy placeholder.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
        let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, Some(bias)),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

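/// Build a linear layer with or without a bias, dispatching on the optional
/// `QuantizedConfig` exactly as [`linear`] and [`linear_no_bias`] do.
///
/// Illustrative call (a sketch; the `cfg` field names are hypothetical):
///
/// ```ignore
/// let gate_proj = linear_b(
///     cfg.hidden_size,          // hypothetical config fields
///     cfg.intermediate_size,
///     /* bias = */ false,
///     &cfg.quantization_config,
///     vb.pp("gate_proj"),
/// )?;
/// let y = gate_proj.forward(&x)?;
/// ```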
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}