1use std::{
2 borrow::Cow,
3 fmt::Debug,
4 num::NonZeroUsize,
5 sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9use candle_core::{
10 quantized::{GgmlDType, QMatMul, QTensor},
11 DType, Device, Result, Tensor,
12};
13
14#[cfg(feature = "metal")]
15mod metal_kernels;
16
17mod afq;
18mod bitsandbytes;
19mod blockwise_fp8;
20pub mod cublaslt;
21pub mod distributed;
22mod dummy;
23mod fp8;
24mod gguf;
25mod gptq;
26mod hqq;
27mod imatrix;
28mod lora;
29mod mxfp4;
30pub mod rotary;
31pub mod safetensors;
32mod scalar_fp8;
33mod unquantized;
34mod utils;
35mod vector_fp8;
36
37use gptq::gptq_linear;
38use lora::merge_lora_weights;
39use regex::Regex;
40pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
41
42pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
43pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
44pub use blockwise_fp8::{
45 blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
46};
47pub use distributed::{
48 layers::{
49 compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
50 ReplicatedLayer, RowParallelLayer,
51 },
52 socket::{Client, Server},
53 BarrierLike, Comm, Id, RingConfig, SumAllReduce,
54};
55pub use dummy::DummyLayer;
56pub use fp8::FP8Linear;
57pub use gguf::GgufMatMul;
58pub use gptq::GptqLayer;
59pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
60pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
61pub use lora::{
62 clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
63 LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
64};
65pub use mxfp4::MXFP4Layer;
66pub use unquantized::UnquantLinear;
67#[cfg(feature = "cuda")]
68pub use utils::gptoss_swiglu_fused;
69#[cfg(feature = "cuda")]
70pub use utils::gptoss_swiglu_interleaved;
71pub use utils::isq::apply_immediate_isq;
72#[cfg(feature = "cuda")]
73pub use utils::softmax_with_sinks;
74pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
75pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
76
77use candle_nn::{Conv1d, Conv2d, Linear, Module};
78use serde::{Deserialize, Deserializer, Serialize};
79
80#[derive(Clone, Debug)]
81pub struct ImmediateIsqParams {
82 pub guard: QuantizeOntoGuard,
83 pub ty: Option<IsqType>,
84 pub predicates: Vec<Regex>,
85 pub overrides: Vec<ImmediateIsqOverride>,
86}
87
88#[derive(Clone, Debug)]
89pub struct ImmediateIsqOverride {
90 pub predicate: Regex,
91 pub ty: Option<IsqType>,
92 pub device: Option<Device>,
93}
94
95#[derive(Clone, Debug)]
96pub struct ImmediateIsqMatch {
97 pub ty: IsqType,
98 pub device: Option<Device>,
99}
100
101thread_local! {
102 static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
103}
104
105pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
106 set_immediate_isq_with_overrides(isq, predicates, Vec::new());
107}
108
109pub fn set_immediate_isq_with_overrides(
110 isq: Option<IsqType>,
111 predicates: Vec<Regex>,
112 overrides: Vec<ImmediateIsqOverride>,
113) {
114 ENGINE_IMMEDIATE_ISQ.with(|cell| {
115 *cell.borrow_mut() = Some(ImmediateIsqParams {
116 guard: QuantizeOntoGuard::new(),
117 ty: isq,
118 predicates,
119 overrides,
120 });
121 });
122}
123
124pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
125 ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
126}
127
128pub fn clear_immediate_isq() {
129 ENGINE_IMMEDIATE_ISQ.with(|cell| {
130 *cell.borrow_mut() = None;
131 });
132}
133
134pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
135 immediate_isq_match(vb).is_some()
136}
137
138pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
139 let immediate_isq = get_immediate_isq()?;
140 let prefix = format!("{}.weight", vb.prefix());
142 resolve_immediate_isq(&immediate_isq, &prefix)
143}
144
145fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
146 if let Some(override_hit) = params
147 .overrides
148 .iter()
149 .find(|override_pred| override_pred.predicate.is_match(prefix))
150 {
151 if let Some(ty) = override_hit.ty.or(params.ty) {
152 return Some(ImmediateIsqMatch {
153 ty,
154 device: override_hit.device.clone(),
155 });
156 }
157 return None;
158 }
159
160 if let Some(ty) = params.ty {
161 if params
162 .predicates
163 .iter()
164 .any(|predicate| predicate.is_match(prefix))
165 {
166 return Some(ImmediateIsqMatch { ty, device: None });
167 }
168 }
169
170 None
171}
172
173#[derive(Debug, Clone, Serialize)]
174#[serde(tag = "quant_method", rename_all = "lowercase")]
175pub enum QuantizedConfig {
176 GptqAwq {
177 bits: usize,
178 group_size: usize,
179 checkpoint_format: Option<String>,
180 is_awq: bool,
181 },
182 Fp8 {
183 weight_block_size: Vec<usize>,
184 },
185 Bitsandbytes {
186 bnb_4bit_quant_type: Option<String>,
187 },
188 Afq {
189 bits: usize,
190 group_size: usize,
191 },
192 MXFP4 {},
193}
194
195#[derive(Deserialize)]
197struct RawConfig {
198 quant_method: Option<String>,
199 bits: Option<usize>,
200 group_size: Option<usize>,
201 checkpoint_format: Option<String>,
202 weight_block_size: Option<Vec<usize>>,
203 bnb_4bit_quant_type: Option<String>,
204}
205
206impl<'de> Deserialize<'de> for QuantizedConfig {
208 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
209 where
210 D: Deserializer<'de>,
211 {
212 let raw = RawConfig::deserialize(deserializer)?;
213
214 match &raw.quant_method {
215 Some(m) if m == "gptq" || m == "awq" => {
216 let bits = raw
217 .bits
218 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
219 let group_size = raw
220 .group_size
221 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
222 Ok(QuantizedConfig::GptqAwq {
223 bits,
224 group_size,
225 checkpoint_format: raw.checkpoint_format,
226 is_awq: m == "awq",
227 })
228 }
229 Some(m) if m == "fp8" => {
230 let weight_block_size = raw
231 .weight_block_size
232 .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
233 Ok(QuantizedConfig::Fp8 { weight_block_size })
234 }
235 Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
236 bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
237 }),
238 Some(m) if m == "afq" => {
239 let bits = raw
240 .bits
241 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
242 let group_size = raw
243 .group_size
244 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
245 Ok(QuantizedConfig::Afq { bits, group_size })
246 }
247 Some(m) if m == "mxfp4" => {
248 Ok(QuantizedConfig::MXFP4 { })
249 }
250 None => {
251 let bits = raw
252 .bits
253 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
254 let group_size = raw
255 .group_size
256 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
257 Ok(QuantizedConfig::Afq { bits, group_size })
258 }
259 Some(unknown_method) => {
260 Err(serde::de::Error::custom(format!(
261 "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
262 )))
263 },
264 }
265 }
266}
267
268impl QuantizedConfig {
269 pub fn name(&self) -> &'static str {
270 match self {
271 Self::GptqAwq { .. } => "gptq",
272 Self::Fp8 { .. } => "fp8",
273 Self::Bitsandbytes { .. } => "bitsandbytes",
274 Self::Afq { .. } => "afq",
275 Self::MXFP4 { .. } => "mxfp4",
276 }
277 }
278
279 pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
280 match self {
281 Self::GptqAwq { bits, .. } => format!("{bits} bits"),
282 Self::Fp8 { .. } => "8 bits".to_string(),
283 Self::Bitsandbytes {
284 bnb_4bit_quant_type: Some(_),
285 } => "4 bits".to_string(),
286 Self::Bitsandbytes {
287 bnb_4bit_quant_type: None,
288 } => "8 bits".to_string(),
289 Self::Afq { bits, .. } => format!("{bits} bits"),
290 Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
291 }
292 }
293
294 pub fn pack_factor(&self, dtype: DType) -> usize {
295 match self {
296 Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
297 2 => IsqType::Q2K.pack_factor(dtype),
298 3 => IsqType::Q3K.pack_factor(dtype),
299 4 => IsqType::Q4K.pack_factor(dtype),
300 5 => IsqType::Q5K.pack_factor(dtype),
301 6 => IsqType::Q6K.pack_factor(dtype),
302 8 => IsqType::Q8_0.pack_factor(dtype),
303 40 => 4, other => panic!("Unexpected bits in `pack_factor` {other}"),
305 },
306 Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
307 Self::Bitsandbytes {
308 bnb_4bit_quant_type: Some(_),
309 }
310 | Self::Bitsandbytes {
311 bnb_4bit_quant_type: None,
312 } => IsqType::Q4K.pack_factor(dtype),
313 Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
314 }
315 }
316}
317
318#[derive(Debug, Clone)]
319pub enum QuantMethodConfig {
320 GptqAwq {
321 bits: i32,
322 use_exllama: bool,
323 q_weight: Tensor,
324 qzeros: Option<Tensor>,
325 scales: Tensor,
326 g_idx: Option<Tensor>,
327 bias: Option<Tensor>,
328 workspace: Option<Tensor>,
329 is_marlin: bool,
330 is_awq: bool,
331 },
332 Gguf {
333 q_weight: Arc<QTensor>,
334 b: Option<Tensor>,
335 },
336 Unquantized(Linear),
337 Hqq {
338 tensor: Tensor,
339 bits: HqqBits,
340 group_size: NonZeroUsize,
341 axis: HqqAxis,
342 optimization_steps: Option<usize>,
343 round_zeros: Option<bool>,
344 channel_wise: Option<bool>,
345 bias: Option<Tensor>,
346 },
347 Dummy,
348 FP8 {
349 lin: Linear,
350 dtype: DType,
351 },
352 Bnb {
353 weight: Tensor,
354 bias: Option<Tensor>,
355 params: BnbQuantParams,
356 quant_ty: BnbQuantType,
357 },
358 BlockwiseFP8 {
359 weight: Tensor,
360 weight_scale_inv: Tensor,
361 bias: Option<Tensor>,
362 dequant_dtype: DType,
363 weight_block_size: Vec<usize>,
364 },
365 Afq {
366 weight: Tensor,
367 bias: Option<Tensor>,
368 bits: AfqBits,
369 group_size: AfqGroupSize,
370 },
371 MXFP4 {
372 blocks: Tensor,
373 scales: Tensor,
374 bias: Option<Tensor>,
375 },
376}
377
378pub struct MatMul;
381
382impl MatMul {
383 pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
385 #[cfg(feature = "accelerate")]
386 {
387 let original_dtype = a.dtype();
388 a.to_dtype(DType::F32)?
389 .matmul(&b.to_dtype(DType::F32)?)?
390 .to_dtype(original_dtype)
391 }
392 #[cfg(not(feature = "accelerate"))]
393 {
394 if a.device().is_cpu() {
395 let original_dtype = a.dtype();
396 a.to_dtype(DType::F16)?
397 .matmul(&b.to_dtype(DType::F16)?)?
398 .to_dtype(original_dtype)
399 } else {
400 a.matmul(b)
401 }
402 }
403 }
404
405 pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
408 self.matmul(a, b)? / scale
410 }
411
412 pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
415 self.matmul(a, b)? * scale
417 }
418
419 pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
421 matmul.forward(x)
422 }
423
424 pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
426 matmul.forward(x)
427 }
428}
429
430pub struct Convolution;
433
434impl Convolution {
435 pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
436 if x.device().is_cpu() {
437 let original_dtype = x.dtype();
438 Conv1d::new(
439 layer.weight().to_dtype(DType::F32)?,
440 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
441 *layer.config(),
442 )
443 .forward(&x.to_dtype(DType::F32)?)?
444 .to_dtype(original_dtype)
445 } else {
446 layer.forward(x)
447 }
448 }
449
450 pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
451 if x.device().is_cpu() {
452 let original_dtype = x.dtype();
453 Conv2d::new(
454 layer.weight().to_dtype(DType::F32)?,
455 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
456 *layer.config(),
457 )
458 .forward(&x.to_dtype(DType::F32)?)?
459 .to_dtype(original_dtype)
460 } else {
461 layer.forward(x)
462 }
463 }
464}
465
466#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
467pub enum IsqType {
468 Q4_0,
469 Q4_1,
470 Q5_0,
471 Q5_1,
472 Q8_0,
473 Q8_1,
474 Q2K,
475 Q3K,
476 Q4K,
477 Q5K,
478 Q6K,
479 Q8K,
480 HQQ8,
481 HQQ4,
482 F8E4M3,
486 AFQ8,
487 AFQ6,
488 AFQ4,
489 AFQ3,
490 AFQ2,
491}
492
493impl IsqType {
494 pub fn pack_factor(&self, dtype: DType) -> usize {
497 match self {
498 Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
499 .div_ceil(GgmlDType::Q4_0.type_size()),
500 Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
501 .div_ceil(GgmlDType::Q4_1.type_size()),
502 Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
503 .div_ceil(GgmlDType::Q5_0.type_size()),
504 Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
505 .div_ceil(GgmlDType::Q5_1.type_size()),
506 Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
507 .div_ceil(GgmlDType::Q8_0.type_size()),
508 Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
509 .div_ceil(GgmlDType::Q8_1.type_size()),
510 Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
511 .div_ceil(GgmlDType::Q2K.type_size()),
512 Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
513 .div_ceil(GgmlDType::Q3K.type_size()),
514 Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
515 .div_ceil(GgmlDType::Q4K.type_size()),
516 Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
517 .div_ceil(GgmlDType::Q5K.type_size()),
518 Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
519 .div_ceil(GgmlDType::Q6K.type_size()),
520 Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
521 .div_ceil(GgmlDType::Q8K.type_size()),
522 Self::HQQ4 => 4,
524 Self::HQQ8 => 2,
525 Self::F8E4M3 => 2,
526 }
527 }
528
529 pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
530 match self {
531 IsqType::HQQ4
533 | IsqType::HQQ8
534 | IsqType::AFQ2
535 | IsqType::AFQ3
536 | IsqType::AFQ4
537 | IsqType::AFQ6
538 | IsqType::AFQ8 => {
539 Some(1.try_into().unwrap())
541 }
542 IsqType::F8E4M3 => None,
543 IsqType::Q2K
544 | IsqType::Q3K
545 | IsqType::Q4K
546 | IsqType::Q4_0
547 | IsqType::Q4_1
548 | IsqType::Q5K
549 | IsqType::Q5_0
550 | IsqType::Q5_1
551 | IsqType::Q6K
552 | IsqType::Q8K
553 | IsqType::Q8_0
554 | IsqType::Q8_1 => None,
555 }
556 }
557}
558
559impl TryFrom<IsqType> for GgmlDType {
560 type Error = candle_core::Error;
561
562 fn try_from(value: IsqType) -> Result<Self> {
563 let tp = match value {
564 IsqType::Q2K => Self::Q2K,
565 IsqType::Q3K => Self::Q3K,
566 IsqType::Q4K => Self::Q4K,
567 IsqType::Q4_0 => Self::Q4_0,
568 IsqType::Q4_1 => Self::Q4_1,
569 IsqType::Q5K => Self::Q5K,
570 IsqType::Q5_0 => Self::Q5_0,
571 IsqType::Q5_1 => Self::Q5_1,
572 IsqType::Q6K => Self::Q6K,
573 IsqType::Q8K => Self::Q8K,
574 IsqType::Q8_0 => Self::Q8_0,
575 IsqType::Q8_1 => Self::Q8_1,
576 _ => candle_core::bail!("Expected valid GGML ISQ type."),
577 };
578 #[cfg(feature = "cuda")]
579 {
580 if !matches!(
581 tp,
582 GgmlDType::Q4_0
583 | GgmlDType::Q4_1
584 | GgmlDType::Q5_0
585 | GgmlDType::Q5_1
586 | GgmlDType::Q8_0
587 | GgmlDType::Q2K
588 | GgmlDType::Q3K
589 | GgmlDType::Q4K
590 | GgmlDType::Q5K
591 | GgmlDType::Q6K
592 ) {
593 candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
594 }
595 }
596 Ok(tp)
597 }
598}
599
600impl TryFrom<GgmlDType> for IsqType {
601 type Error = candle_core::Error;
602
603 fn try_from(value: GgmlDType) -> Result<Self> {
604 match value {
605 GgmlDType::Q2K => Ok(Self::Q2K),
606 GgmlDType::Q3K => Ok(Self::Q3K),
607 GgmlDType::Q4K => Ok(Self::Q4K),
608 GgmlDType::Q5K => Ok(Self::Q5K),
609 GgmlDType::Q6K => Ok(Self::Q6K),
610 GgmlDType::Q4_0 => Ok(Self::Q4_0),
611 GgmlDType::Q4_1 => Ok(Self::Q4_1),
612 GgmlDType::Q5_0 => Ok(Self::Q5_0),
613 GgmlDType::Q5_1 => Ok(Self::Q5_1),
614 GgmlDType::Q8_0 => Ok(Self::Q8_0),
615 GgmlDType::Q8_1 => Ok(Self::Q8_1),
616 GgmlDType::Q8K => Ok(Self::Q8K),
617 GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
618 candle_core::bail!("Expected valid GGML ISQ type.")
619 }
620 }
621 }
622}
623
/// On-disk tag identifying which serialized quantized-layer format follows.
/// Discriminant values are part of the serialization format — do not reorder.
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}
632
633impl TryFrom<usize> for QuantizedSerdeType {
634 type Error = candle_core::Error;
635 fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
636 match value {
637 0 => Ok(Self::Gguf),
638 1 => Ok(Self::Unquant),
639 2 => Ok(Self::Hqq),
640 3 => Ok(Self::Fp8),
641 4 => Ok(Self::Afq),
642 other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
643 }
644 }
645}
646
647pub trait QuantizedSerde {
648 fn name(&self) -> &'static str;
649 fn isq_serde_supported(&self) -> bool {
650 false
651 }
652 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
653 candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
654 }
655 fn deserialize(
656 _data: Cow<[u8]>,
657 _device: &Device,
658 _comm: &Arc<crate::Comm>,
659 _guard: QuantizeOntoGuard,
660 ) -> Result<Arc<dyn QuantMethod>>
661 where
662 Self: Sized,
663 {
664 candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
665 }
666 fn deserialize_ext_bias(
667 _data: Cow<[u8]>,
668 _device: &Device,
669 _guard: QuantizeOntoGuard,
670 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
671 where
672 Self: Sized,
673 {
674 candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
675 }
676 fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
678 candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
679 }
680}
681
/// Shared mutex used to serialize quantization work across layers; clones
/// share the same underlying lock. See [`QuantizeOntoGuard::acquire`].
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}
688
/// RAII result of [`QuantizeOntoGuard::acquire`]: either a held mutex guard
/// (released on drop) or a no-op placeholder (the CUDA path takes no lock).
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}
694
695impl Default for QuantizeOntoGuard {
696 fn default() -> Self {
697 Self::new()
698 }
699}
700
701impl QuantizeOntoGuard {
702 pub fn new() -> Self {
703 QuantizeOntoGuard {
704 inner: Arc::new(Mutex::new(())),
705 }
706 }
707
708 pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
712 #[cfg(feature = "cuda")]
713 {
714 let _ = device;
715 QuantizeOntoDropGuard::Fake
716 }
717
718 #[cfg(not(feature = "cuda"))]
719 {
720 #[cfg(feature = "metal")]
721 if let Device::Metal(dev) = device {
722 dev.wait_until_completed()
724 .expect("Failed to flush command buffer.");
725 }
726 #[cfg(not(feature = "metal"))]
727 let _ = device;
728
729 QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
730 }
731 }
732}
733
/// Role a layer plays in tensor-parallel execution; reported by
/// [`QuantMethod::is_distributed`].
pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}
739
740pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
742 fn new(method: QuantMethodConfig) -> Result<Self>
743 where
744 Self: Sized;
745
746 fn dequantize_w(&self) -> Result<Tensor>;
747
748 fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
751 let original_ty = a.dtype();
752 let a = if let Some(t) = self.quantized_act_type() {
753 a.to_dtype(t)?
754 } else {
755 a.clone()
756 };
757 self.forward(&a)?.to_dtype(original_ty)
758 }
759
760 fn forward(&self, a: &Tensor) -> Result<Tensor>;
762
763 fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
769 let original_ty = a.dtype();
770 let a = if let Some(t) = self.quantized_act_type() {
771 a.to_dtype(t)?
772 } else {
773 a.clone()
774 };
775 self.gather_forward(&a, indices)?.to_dtype(original_ty)
776 }
777
778 fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
783 candle_core::bail!(
784 "{} does not support `gather_forward`. Please raise an issue.",
785 self.name()
786 )
787 }
788
789 fn quantized_act_type(&self) -> Option<DType>;
791
792 fn dtype_and_device(&self) -> (DType, Device);
794
795 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;
797
798 fn apply_isq(
800 self: Arc<Self>,
801 dtype: Option<IsqType>,
802 device: Device,
803 n_quantized: &AtomicUsize,
804 imatrix_weight: Option<Vec<f32>>,
805 guard: QuantizeOntoGuard,
806 ) -> Result<Arc<dyn QuantMethod>>;
807
808 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
809 None
810 }
811
812 fn begin_track_stats(&mut self) -> Result<()> {
814 candle_core::bail!("`{}` does not support tracking stats.", self.name())
815 }
816
817 fn end_track_stats(&self) -> Result<Tensor> {
819 candle_core::bail!("`{}` does not support tracking stats.", self.name())
820 }
821
822 fn is_distributed(&self) -> Option<DistributedKind> {
823 None
824 }
825}
826
827impl Module for dyn QuantMethod {
828 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
829 Self::forward(self, xs)
830 }
831}
832
833pub fn linear_no_bias(
834 in_dim: usize,
835 out_dim: usize,
836 config: &Option<QuantizedConfig>,
837 vb: ShardedVarBuilder,
838) -> Result<Arc<dyn QuantMethod>> {
839 let base_vb = vb.clone();
840 let vb = if should_apply_immediate_isq(&vb) {
841 vb.set_device(Device::Cpu)
842 } else {
843 vb
844 };
845
846 let layer = if let Some(quant_conf) = &config {
847 match quant_conf {
848 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
849 QuantizedConfig::Fp8 { .. } => {
850 blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
851 }
852 QuantizedConfig::Bitsandbytes { .. } => {
853 Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
854 }
855 QuantizedConfig::Afq { .. } => {
856 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
857 }
858 QuantizedConfig::MXFP4 {} => {
859 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
860 }
861 }
862 } else {
863 if !vb.contains_tensor("weight") {
865 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
866 Arc::new(layer) as Arc<dyn QuantMethod>
867 } else {
868 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
869 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
870
871 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
872 Linear::new(weight, None),
873 ))?;
874 Arc::new(layer) as Arc<dyn QuantMethod>
875 }
876 };
877 apply_immediate_isq(layer, base_vb)
878}
879
880pub fn linear(
881 in_dim: usize,
882 out_dim: usize,
883 config: &Option<QuantizedConfig>,
884 vb: ShardedVarBuilder,
885) -> Result<Arc<dyn QuantMethod>> {
886 let base_vb = vb.clone();
887 let vb = if should_apply_immediate_isq(&vb) {
888 vb.set_device(Device::Cpu)
889 } else {
890 vb
891 };
892
893 let layer = if let Some(quant_conf) = &config {
894 match quant_conf {
895 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
896 QuantizedConfig::Fp8 { .. } => {
897 blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
898 }
899 QuantizedConfig::Bitsandbytes { .. } => {
900 Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
901 }
902 QuantizedConfig::Afq { .. } => {
903 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
904 }
905 QuantizedConfig::MXFP4 {} => {
906 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
907 }
908 }
909 } else {
910 if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
912 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
913 Arc::new(layer) as Arc<dyn QuantMethod>
914 } else {
915 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
916 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
917 let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
918
919 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
920 Linear::new(weight, Some(bias)),
921 ))?;
922 Arc::new(layer) as Arc<dyn QuantMethod>
923 }
924 };
925 apply_immediate_isq(layer, base_vb)
926}
927
928pub fn linear_b(
929 in_dim: usize,
930 out_dim: usize,
931 bias: bool,
932 config: &Option<QuantizedConfig>,
933 vb: ShardedVarBuilder,
934) -> Result<Arc<dyn QuantMethod>> {
935 if bias {
936 linear(in_dim, out_dim, config, vb)
937 } else {
938 linear_no_bias(in_dim, out_dim, config, vb)
939 }
940}