1use std::{
2 borrow::Cow,
3 fmt::Debug,
4 num::NonZeroUsize,
5 sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9use candle_core::{
10 quantized::{GgmlDType, QMatMul, QTensor},
11 DType, Device, Result, Tensor,
12};
13
14#[cfg(feature = "metal")]
15mod metal_kernels;
16
17mod afq;
18mod bitsandbytes;
19mod blockwise_fp8;
20pub mod cublaslt;
21pub mod distributed;
22mod dummy;
23mod fp8;
24mod gguf;
25mod gptq;
26mod hqq;
27mod imatrix;
28mod lora;
29mod mxfp4;
30pub mod rotary;
31pub mod safetensors;
32mod scalar_fp8;
33mod unquantized;
34mod utils;
35mod vector_fp8;
36
37use gptq::gptq_linear;
38use lora::merge_lora_weights;
39use regex::Regex;
40pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
41
42pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
43pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
44pub use blockwise_fp8::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
45pub use distributed::{
46 layers::{
47 compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
48 ReplicatedLayer, RowParallelLayer,
49 },
50 socket::{Client, Server},
51 BarrierLike, Comm, Id, RingConfig, SumAllReduce,
52};
53pub use dummy::DummyLayer;
54pub use fp8::FP8Linear;
55pub use gguf::GgufMatMul;
56pub use gptq::GptqLayer;
57pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
58pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
59pub use lora::{
60 clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
61 LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
62};
63pub use mxfp4::MXFP4Layer;
64pub use unquantized::UnquantLinear;
65pub use utils::isq::apply_immediate_isq;
66pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
67pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
68
69use candle_nn::{Conv1d, Conv2d, Linear, Module};
70use serde::{Deserialize, Deserializer, Serialize};
71
/// Thread-local configuration for "immediate ISQ": quantizing weights in-place
/// as they are loaded, rather than in a separate pass afterwards.
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    /// Guard serializing quantize-onto work (see [`QuantizeOntoGuard`]).
    pub guard: QuantizeOntoGuard,
    /// Default quantization type applied when one of `predicates` matches.
    pub ty: Option<IsqType>,
    /// Regexes matched against the tensor name (`"{prefix}.weight"`).
    pub predicates: Vec<Regex>,
    /// Per-tensor overrides; consulted before `predicates` (see `resolve_immediate_isq`).
    pub overrides: Vec<ImmediateIsqOverride>,
}
79
/// A per-tensor override for immediate ISQ. If `predicate` matches the tensor
/// name, this entry wins over the global predicates; a `None` `ty` falls back
/// to the global type, and a matching override with no resolvable type
/// suppresses quantization for that tensor.
#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    /// Regex matched against `"{prefix}.weight"`.
    pub predicate: Regex,
    /// Quantization type for this tensor; `None` defers to the global type.
    pub ty: Option<IsqType>,
    /// Optional target device for the quantized tensor.
    pub device: Option<Device>,
}
86
/// Result of resolving the immediate-ISQ rules for a single tensor:
/// the quantization type to apply and an optional target device.
#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    /// Quantization type to apply.
    pub ty: IsqType,
    /// Target device from a matching override; `None` keeps the default device.
    pub device: Option<Device>,
}
92
thread_local! {
    // Per-thread immediate-ISQ configuration; `None` disables immediate ISQ
    // on this thread. Set/cleared via the `set_immediate_isq*` / `clear_immediate_isq` fns.
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
}
96
97pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
98 set_immediate_isq_with_overrides(isq, predicates, Vec::new());
99}
100
101pub fn set_immediate_isq_with_overrides(
102 isq: Option<IsqType>,
103 predicates: Vec<Regex>,
104 overrides: Vec<ImmediateIsqOverride>,
105) {
106 ENGINE_IMMEDIATE_ISQ.with(|cell| {
107 *cell.borrow_mut() = Some(ImmediateIsqParams {
108 guard: QuantizeOntoGuard::new(),
109 ty: isq,
110 predicates,
111 overrides,
112 });
113 });
114}
115
116pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
117 ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
118}
119
120pub fn clear_immediate_isq() {
121 ENGINE_IMMEDIATE_ISQ.with(|cell| {
122 *cell.borrow_mut() = None;
123 });
124}
125
126pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
127 immediate_isq_match(vb).is_some()
128}
129
130pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
131 let immediate_isq = get_immediate_isq()?;
132 let prefix = format!("{}.weight", vb.prefix());
134 resolve_immediate_isq(&immediate_isq, &prefix)
135}
136
137fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
138 if let Some(override_hit) = params
139 .overrides
140 .iter()
141 .find(|override_pred| override_pred.predicate.is_match(prefix))
142 {
143 if let Some(ty) = override_hit.ty.or(params.ty) {
144 return Some(ImmediateIsqMatch {
145 ty,
146 device: override_hit.device.clone(),
147 });
148 }
149 return None;
150 }
151
152 if let Some(ty) = params.ty {
153 if params
154 .predicates
155 .iter()
156 .any(|predicate| predicate.is_match(prefix))
157 {
158 return Some(ImmediateIsqMatch { ty, device: None });
159 }
160 }
161
162 None
163}
164
/// Parsed `quantization_config` from a model's config file.
///
/// Serialization uses the `quant_method` tag; deserialization is implemented
/// manually (see the `Deserialize` impl below) to validate per-method fields.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    /// GPTQ or AWQ checkpoints (distinguished by `is_awq`).
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    /// Blockwise FP8 quantization.
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    /// bitsandbytes; `Some(_)` quant type indicates 4-bit, `None` 8-bit
    /// (see `get_bits_name`).
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    /// AFQ (MLX-style affine quantization).
    Afq {
        bits: usize,
        group_size: usize,
    },
    /// MXFP4 microscaling format; carries no extra fields.
    MXFP4 {},
}
186
/// Untagged superset of all per-method fields, used as an intermediate by the
/// manual `Deserialize` impl for [`QuantizedConfig`]; each method validates
/// the presence of the fields it requires.
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}
197
198impl<'de> Deserialize<'de> for QuantizedConfig {
200 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
201 where
202 D: Deserializer<'de>,
203 {
204 let raw = RawConfig::deserialize(deserializer)?;
205
206 match &raw.quant_method {
207 Some(m) if m == "gptq" || m == "awq" => {
208 let bits = raw
209 .bits
210 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
211 let group_size = raw
212 .group_size
213 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
214 Ok(QuantizedConfig::GptqAwq {
215 bits,
216 group_size,
217 checkpoint_format: raw.checkpoint_format,
218 is_awq: m == "awq",
219 })
220 }
221 Some(m) if m == "fp8" => {
222 let weight_block_size = raw
223 .weight_block_size
224 .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
225 Ok(QuantizedConfig::Fp8 { weight_block_size })
226 }
227 Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
228 bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
229 }),
230 Some(m) if m == "afq" => {
231 let bits = raw
232 .bits
233 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
234 let group_size = raw
235 .group_size
236 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
237 Ok(QuantizedConfig::Afq { bits, group_size })
238 }
239 Some(m) if m == "mxfp4" => {
240 Ok(QuantizedConfig::MXFP4 { })
241 }
242 None => {
243 let bits = raw
244 .bits
245 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
246 let group_size = raw
247 .group_size
248 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
249 Ok(QuantizedConfig::Afq { bits, group_size })
250 }
251 Some(unknown_method) => {
252 Err(serde::de::Error::custom(format!(
253 "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
254 )))
255 },
256 }
257 }
258}
259
260impl QuantizedConfig {
261 pub fn name(&self) -> &'static str {
262 match self {
263 Self::GptqAwq { .. } => "gptq",
264 Self::Fp8 { .. } => "fp8",
265 Self::Bitsandbytes { .. } => "bitsandbytes",
266 Self::Afq { .. } => "afq",
267 Self::MXFP4 { .. } => "mxfp4",
268 }
269 }
270
271 pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
272 match self {
273 Self::GptqAwq { bits, .. } => format!("{bits} bits"),
274 Self::Fp8 { .. } => "8 bits".to_string(),
275 Self::Bitsandbytes {
276 bnb_4bit_quant_type: Some(_),
277 } => "4 bits".to_string(),
278 Self::Bitsandbytes {
279 bnb_4bit_quant_type: None,
280 } => "8 bits".to_string(),
281 Self::Afq { bits, .. } => format!("{bits} bits"),
282 Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
283 }
284 }
285
286 pub fn pack_factor(&self, dtype: DType) -> usize {
287 match self {
288 Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
289 2 => IsqType::Q2K.pack_factor(dtype),
290 3 => IsqType::Q3K.pack_factor(dtype),
291 4 => IsqType::Q4K.pack_factor(dtype),
292 5 => IsqType::Q5K.pack_factor(dtype),
293 6 => IsqType::Q6K.pack_factor(dtype),
294 8 => IsqType::Q8_0.pack_factor(dtype),
295 40 => 4, other => panic!("Unexpected bits in `pack_factor` {other}"),
297 },
298 Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
299 Self::Bitsandbytes {
300 bnb_4bit_quant_type: Some(_),
301 }
302 | Self::Bitsandbytes {
303 bnb_4bit_quant_type: None,
304 } => IsqType::Q4K.pack_factor(dtype),
305 Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
306 }
307 }
308}
309
/// Construction parameters for a [`QuantMethod`] implementation; each variant
/// corresponds to one backend layer type.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    /// GPTQ/AWQ layer inputs (packed weights, zeros, scales, optional g_idx).
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    /// GGUF quantized matmul: a quantized weight plus optional bias.
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    /// Plain unquantized linear layer.
    Unquantized(Linear),
    /// HQQ quantization inputs; quantization happens at construction.
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    /// Placeholder layer with no parameters (used when tensors are absent).
    Dummy,
    /// FP8 linear with an associated compute dtype.
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    /// bitsandbytes layer inputs.
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    /// Blockwise FP8: weight plus inverse per-block scales.
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    /// AFQ layer inputs.
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    /// MXFP4 layer inputs: packed blocks and per-block scales.
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
369
/// Zero-sized helper centralizing matmul dispatch with per-platform dtype
/// workarounds.
pub struct MatMul;

impl MatMul {
    /// Compute `a @ b`, casting back to `a`'s dtype.
    ///
    /// With the `accelerate` feature the matmul always runs in F32; otherwise
    /// CPU matmuls run in F16 and other devices use the dtype as-is.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute `(a @ b) / scale`.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    /// Compute `(a @ b) * scale`.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    /// Forward `x` through a candle `QMatMul` (no dtype workarounds).
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Forward `x` through any [`QuantMethod`] (no dtype workarounds).
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
421
422pub struct Convolution;
425
426impl Convolution {
427 pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
428 if x.device().is_cpu() {
429 let original_dtype = x.dtype();
430 Conv1d::new(
431 layer.weight().to_dtype(DType::F32)?,
432 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
433 *layer.config(),
434 )
435 .forward(&x.to_dtype(DType::F32)?)?
436 .to_dtype(original_dtype)
437 } else {
438 layer.forward(x)
439 }
440 }
441
442 pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
443 if x.device().is_cpu() {
444 let original_dtype = x.dtype();
445 Conv2d::new(
446 layer.weight().to_dtype(DType::F32)?,
447 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
448 *layer.config(),
449 )
450 .forward(&x.to_dtype(DType::F32)?)?
451 .to_dtype(original_dtype)
452 } else {
453 layer.forward(x)
454 }
455 }
456}
457
/// Supported in-situ quantization (ISQ) target types.
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    // GGML/GGUF quantization formats (convertible to `GgmlDType`).
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    // HQQ formats.
    HQQ8,
    HQQ4,
    // FP8 (E4M3).
    F8E4M3,
    // AFQ formats.
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}
484
impl IsqType {
    /// Compression factor relative to `dtype`: the byte size of one block of
    /// unquantized elements in `dtype` divided by the quantized block's byte
    /// size, rounded up. AFQ widths reuse the GGML type of the same bit width.
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Fixed factors for the non-GGML formats.
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    /// Maximum CPU thread count to use when quantizing to this type.
    /// `None` means no restriction.
    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            // HQQ and AFQ are limited to a single thread here — presumably
            // their quantizers are not parallel-safe or are memory-heavy;
            // TODO(review): confirm the underlying reason.
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}
550
/// Convert an ISQ type into the corresponding GGML dtype.
///
/// Fails for non-GGML ISQ types (HQQ, FP8, AFQ). On CUDA builds the set is
/// further restricted to the dtypes with CUDA kernels.
impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            // Q8_1 and Q8K convert above but have no CUDA support, so reject them.
            // NOTE(review): the message mentions HQQ8/HQQ4, which can never
            // reach this check (they bail in the match above).
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
            }
        }
        Ok(tp)
    }
}
591
592impl TryFrom<GgmlDType> for IsqType {
593 type Error = candle_core::Error;
594
595 fn try_from(value: GgmlDType) -> Result<Self> {
596 match value {
597 GgmlDType::Q2K => Ok(Self::Q2K),
598 GgmlDType::Q3K => Ok(Self::Q3K),
599 GgmlDType::Q4K => Ok(Self::Q4K),
600 GgmlDType::Q5K => Ok(Self::Q5K),
601 GgmlDType::Q6K => Ok(Self::Q6K),
602 GgmlDType::Q4_0 => Ok(Self::Q4_0),
603 GgmlDType::Q4_1 => Ok(Self::Q4_1),
604 GgmlDType::Q5_0 => Ok(Self::Q5_0),
605 GgmlDType::Q5_1 => Ok(Self::Q5_1),
606 GgmlDType::Q8_0 => Ok(Self::Q8_0),
607 GgmlDType::Q8_1 => Ok(Self::Q8_1),
608 GgmlDType::Q8K => Ok(Self::Q8K),
609 GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
610 candle_core::bail!("Expected valid GGML ISQ type.")
611 }
612 }
613 }
614}
615
/// Discriminant written into serialized quantized-layer blobs; the numeric
/// values are part of the on-disk format (see `TryFrom<usize>` below).
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}
624
625impl TryFrom<usize> for QuantizedSerdeType {
626 type Error = candle_core::Error;
627 fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
628 match value {
629 0 => Ok(Self::Gguf),
630 1 => Ok(Self::Unquant),
631 2 => Ok(Self::Hqq),
632 3 => Ok(Self::Fp8),
633 4 => Ok(Self::Afq),
634 other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
635 }
636 }
637}
638
/// (De)serialization hooks for quantized layers (e.g. UQFF). All methods have
/// failing defaults; implementors opt in by overriding them and returning
/// `true` from `isq_serde_supported`.
pub trait QuantizedSerde {
    /// Implementation name, used in error messages and logs.
    fn name(&self) -> &'static str;
    /// Whether this layer supports serialization; defaults to `false`.
    fn isq_serde_supported(&self) -> bool {
        false
    }
    /// Serialize this layer to bytes. Default: unsupported.
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    /// Deserialize a layer from bytes onto `device`. Default: unsupported.
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    /// Deserialize a layer plus an externally stored bias. Default: unsupported.
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// Serialize this layer together with an external bias. Default: unsupported.
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}
673
/// Cloneable guard serializing quantize-onto-device work across threads
/// (presumably to bound peak memory/contention during ISQ — confirm with
/// callers). Clones share one underlying mutex; see `acquire`.
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}
680
/// RAII token returned by [`QuantizeOntoGuard::acquire`]: a real mutex guard
/// on non-CUDA builds, or a no-op placeholder on CUDA builds.
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}
686
687impl Default for QuantizeOntoGuard {
688 fn default() -> Self {
689 Self::new()
690 }
691}
692
impl QuantizeOntoGuard {
    /// Create a guard with a fresh (unshared) mutex.
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the guard, blocking until any other holder releases it.
    ///
    /// On CUDA builds this is a no-op (`Fake`) — quantization there is not
    /// serialized. On Metal builds the device's command buffer is flushed
    /// before taking the lock.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned, or (Metal) if the flush fails.
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.flush_command_buffer()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
725
/// How a layer participates in tensor-parallel execution
/// (see [`QuantMethod::is_distributed`]).
pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}
731
/// Common interface for all (possibly quantized) linear-layer backends.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    /// Construct the layer from backend-specific parameters.
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    /// Dequantize the weight back to a dense tensor.
    fn dequantize_w(&self) -> Result<Tensor>;

    /// `forward`, but first casting `a` to `quantized_act_type()` (if any)
    /// and casting the result back to `a`'s original dtype.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute the layer's forward pass on `a` (no dtype handling).
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// `gather_forward` with the same autocasting behavior as
    /// `forward_autocast`.
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Indexed forward pass (e.g. for expert selection); unsupported by default.
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// Activation dtype the backend requires, or `None` to accept any dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// The layer's (dtype, device).
    fn dtype_and_device(&self) -> (DType, Device);

    /// Return a new layer with `delta` added to the weight (e.g. LoRA merge).
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Quantize this layer to `dtype` on `device`, updating the shared
    /// `n_quantized` counter; `imatrix_weight` optionally supplies importance
    /// data, and `guard` serializes the work (see `QuantizeOntoGuard`).
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    /// Unquantized weight and bias, if this backend can expose them.
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Start collecting activation statistics (imatrix); unsupported by default.
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// Finish stats collection and return the collected data; unsupported by default.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// This layer's tensor-parallel role, if any.
    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}
818
/// Let `dyn QuantMethod` be used wherever a candle `Module` is expected.
impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Delegates to `QuantMethod::forward` (not the autocasting variant).
        Self::forward(self, xs)
    }
}
824
825pub fn linear_no_bias(
826 in_dim: usize,
827 out_dim: usize,
828 config: &Option<QuantizedConfig>,
829 vb: ShardedVarBuilder,
830) -> Result<Arc<dyn QuantMethod>> {
831 let base_vb = vb.clone();
832 let vb = if should_apply_immediate_isq(&vb) {
833 vb.set_device(Device::Cpu)
834 } else {
835 vb
836 };
837
838 let layer = if let Some(quant_conf) = &config {
839 match quant_conf {
840 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
841 QuantizedConfig::Fp8 { .. } => {
842 blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
843 }
844 QuantizedConfig::Bitsandbytes { .. } => {
845 Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
846 }
847 QuantizedConfig::Afq { .. } => {
848 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
849 }
850 QuantizedConfig::MXFP4 {} => {
851 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
852 }
853 }
854 } else {
855 if !vb.contains_tensor("weight") {
857 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
858 Arc::new(layer) as Arc<dyn QuantMethod>
859 } else {
860 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
861 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
862
863 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
864 Linear::new(weight, None),
865 ))?;
866 Arc::new(layer) as Arc<dyn QuantMethod>
867 }
868 };
869 apply_immediate_isq(layer, base_vb)
870}
871
872pub fn linear(
873 in_dim: usize,
874 out_dim: usize,
875 config: &Option<QuantizedConfig>,
876 vb: ShardedVarBuilder,
877) -> Result<Arc<dyn QuantMethod>> {
878 let base_vb = vb.clone();
879 let vb = if should_apply_immediate_isq(&vb) {
880 vb.set_device(Device::Cpu)
881 } else {
882 vb
883 };
884
885 let layer = if let Some(quant_conf) = &config {
886 match quant_conf {
887 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
888 QuantizedConfig::Fp8 { .. } => {
889 blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
890 }
891 QuantizedConfig::Bitsandbytes { .. } => {
892 Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
893 }
894 QuantizedConfig::Afq { .. } => {
895 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
896 }
897 QuantizedConfig::MXFP4 {} => {
898 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
899 }
900 }
901 } else {
902 if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
904 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
905 Arc::new(layer) as Arc<dyn QuantMethod>
906 } else {
907 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
908 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
909 let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
910
911 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
912 Linear::new(weight, Some(bias)),
913 ))?;
914 Arc::new(layer) as Arc<dyn QuantMethod>
915 }
916 };
917 apply_immediate_isq(layer, base_vb)
918}
919
920pub fn linear_b(
921 in_dim: usize,
922 out_dim: usize,
923 bias: bool,
924 config: &Option<QuantizedConfig>,
925 vb: ShardedVarBuilder,
926) -> Result<Arc<dyn QuantMethod>> {
927 if bias {
928 linear(in_dim, out_dim, config, vb)
929 } else {
930 linear_no_bias(in_dim, out_dim, config, vb)
931 }
932}