use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

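/// Parameters for "immediate ISQ": a target quantization type plus the regex
/// predicates selecting which weights to quantize as they are loaded.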
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
}

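// Thread-local storage for the current engine thread's immediate ISQ parameters,
// configured and read via the helper functions below.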
thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

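/// Enable immediate ISQ for the current thread: layers whose `{prefix}.weight` name
/// matches any of `predicates` are quantized to `isq` as they are constructed.
///
/// A minimal usage sketch (the predicate pattern is illustrative, not a required
/// naming scheme):
///
/// ```ignore
/// use regex::Regex;
///
/// let predicates = vec![Regex::new(r"model\.layers\.\d+\.mlp\..*").unwrap()];
/// set_immediate_isq(Some(IsqType::Q4K), predicates);
/// // ... build the model's layers on this thread ...
/// clear_immediate_isq();
/// ```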
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

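/// Returns `true` if immediate ISQ is active on this thread and the builder's
/// `{prefix}.weight` matches one of the configured predicates.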
pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    let Some(immediate_isq) = get_immediate_isq() else {
        return false;
    };
    let prefix = format!("{}.weight", vb.prefix());
    immediate_isq.ty.is_some()
        && immediate_isq
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(&prefix))
}

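/// Pre-quantization configuration, typically parsed from the `quantization_config`
/// section of a model's `config.json` (hence the custom `Deserialize` impl below).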
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

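// Untagged superset of the fields used by the supported quantization methods;
// `QuantizedConfig::deserialize` selects the relevant ones based on `quant_method`.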
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
            ))),
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

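    /// Approximate packing factor: how many times smaller the quantized weights are
    /// than unquantized weights stored as `dtype` (delegates to [`IsqType::pack_factor`]).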
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4,
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

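/// Thin wrapper around matrix multiplication: with the `accelerate` feature the
/// operands are routed through f32, and on CPU (without `accelerate`) through f16,
/// restoring the original dtype afterwards; e.g. `MatMul.matmul(&a, &b)?`.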
pub struct MatMul;

impl MatMul {
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

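/// Convolution helpers mirroring `MatMul`: on CPU the weights, bias, and input are
/// promoted to f32 for the forward pass, then the output is cast back to the input dtype.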
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

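    /// Optional cap on the number of CPU threads to use when quantizing to this type;
    /// `None` means no specific limit. HQQ and AFQ variants are capped at one thread.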
    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => Some(1.try_into().unwrap()),
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

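/// Guard used to serialize quantization work onto a device: clones share the same
/// underlying mutex, and `acquire` locks it (except on CUDA builds, where it is a no-op).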
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

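/// RAII token returned by [`QuantizeOntoGuard::acquire`]: `Real` holds the mutex guard,
/// while `Fake` is the no-op variant used on CUDA builds.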
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

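/// Core abstraction over a (possibly quantized) linear layer: construction from a
/// [`QuantMethodConfig`], forward passes with optional activation autocasting, and
/// in-situ re-quantization via `apply_isq`.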
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    fn quantized_act_type(&self) -> Option<DType>;

    fn dtype_and_device(&self) -> (DType, Device);

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

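/// Build a linear layer without bias, choosing the implementation from `config`
/// (GPTQ/AWQ, blockwise FP8, bitsandbytes, AFQ, MXFP4) or falling back to an
/// unquantized layer; a `DummyLayer` is used when the weight tensor is absent.
/// Immediate ISQ, if configured for this prefix, is applied before returning.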
pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

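/// Build a linear layer with bias, using the same selection logic as [`linear_no_bias`];
/// a `DummyLayer` is used when either the weight or the bias tensor is absent.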
pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

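/// Convenience dispatcher: forwards to [`linear`] or [`linear_no_bias`] depending on `bias`.
///
/// A hedged usage sketch (the prefix and dimensions below are illustrative only):
///
/// ```ignore
/// let gate_proj = linear_b(4096, 11008, /* bias */ false, &None, vb.pp("mlp.gate_proj"))?;
/// let y = gate_proj.forward_autocast(&x)?;
/// ```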
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}