use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{
    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
pub use utils::isq::apply_immediate_isq;
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

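/// Parameters for "immediate" in-situ quantization (ISQ), applied to matching tensors as
/// they are loaded: the target quantization type, the regexes selecting which `.weight`
/// prefixes to quantize, and optional per-tensor overrides of type and device.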
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) };
}

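/// Enable immediate ISQ for the current thread: tensors whose `"{prefix}.weight"` name
/// matches any of `predicates` will be quantized to `isq` as they are loaded.
///
/// A minimal usage sketch (the regex shown is hypothetical):
/// ```ignore
/// set_immediate_isq(
///     Some(IsqType::Q4K),
///     vec![Regex::new(r"\.(gate|up|down)_proj\.").unwrap()],
/// );
/// ```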
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

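/// Resolve the effective ISQ type and target device for a tensor prefix. Overrides take
/// precedence: a matching override uses its own type (falling back to the global type) and
/// its device, and suppresses quantization entirely if neither type is set. Otherwise the
/// global type applies whenever any plain predicate matches.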
fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}

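/// Pre-quantization configuration for a checkpoint, typically parsed from the
/// `quantization_config` entry of a model's `config.json` and tagged by `quant_method`.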
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

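// Custom deserialization: dispatch on `quant_method` and require the fields that method
// needs. For example, `{"quant_method": "gptq", "bits": 4, "group_size": 128}` becomes
// `QuantizedConfig::GptqAwq { bits: 4, group_size: 128, .. }`; a config with no
// `quant_method` but with `bits` and `group_size` is treated as AFQ.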
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
            ))),
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

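    /// Approximate size ratio of an unquantized weight of `dtype` to its quantized form,
    /// derived by mapping the configured bit width onto the closest GGUF quantization type.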
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4,
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

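/// Thin wrapper around matrix multiplication that picks a precision per backend: F32 with
/// the `accelerate` feature, F16 on plain CPU, and the tensors' own dtype elsewhere.
///
/// A minimal sketch (assumes `a` and `b` are matmul-compatible tensors):
/// ```ignore
/// let c = MatMul.matmul(&a, &b)?;
/// let scaled = MatMul.matmul_affine_mul(&a, &b, 0.5)?;
/// ```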
pub struct MatMul;

impl MatMul {
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

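/// Wrapper around `Conv1d`/`Conv2d` that upcasts weights, bias, and input to F32 on CPU
/// and casts the output back to the original dtype; on other devices the layer is applied
/// directly.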
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
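    /// Ratio of an unquantized weight's size in `dtype` to its quantized size, computed
    /// from the corresponding GGUF block layout; HQQ4, HQQ8, and F8E4M3 use fixed ratios.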
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

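/// Serialization hooks for quantized layers (used for artifacts such as UQFF). The default
/// methods bail; implementations opt in via `isq_serde_supported` and override the
/// `serialize`/`deserialize` variants they support.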
pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

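/// Shared lock used to serialize quantization work onto a device, so that concurrent
/// `apply_isq` calls do not run device-bound quantization at the same time.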
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

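    /// Acquire the guard before quantizing onto `device`. On CUDA builds no lock is taken
    /// and a `Fake` guard is returned; on Metal the pending command buffer is flushed via
    /// `wait_until_completed` before the real mutex guard is acquired.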
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

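/// Abstraction over a (possibly quantized) linear layer: construction from a
/// `QuantMethodConfig`, plain and gathered forward passes (with optional activation
/// autocasting), weight dequantization, and re-quantization via `apply_isq`.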
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    fn quantized_act_type(&self) -> Option<DType>;

    fn dtype_and_device(&self) -> (DType, Device);

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

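/// Build a linear layer without bias, dispatching on the optional pre-quantization config
/// (GPTQ/AWQ, blockwise FP8, bitsandbytes, AFQ, MXFP4) and falling back to an unquantized
/// (or dummy) layer; immediate ISQ is applied afterwards when configured.
///
/// A hypothetical call, assuming `cfg.quantization_config: Option<QuantizedConfig>` and a
/// `ShardedVarBuilder` rooted at the parent module:
/// ```ignore
/// let q_proj = linear_no_bias(hidden_size, hidden_size, &cfg.quantization_config, vb.pp("q_proj"))?;
/// ```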
pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

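/// Convenience wrapper that forwards to `linear` or `linear_no_bias` depending on `bias`.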
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}