use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
pub mod gemv;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{
    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
#[cfg(feature = "cuda")]
pub use gemv::gemv;
pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use unquantized::UnquantLinear;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_fused;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_interleaved;
pub use utils::isq::apply_immediate_isq;
#[cfg(feature = "cuda")]
pub use utils::softmax_with_sinks;
pub use utils::{fused_glu, GluActivationType};
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> =
        const { std::cell::RefCell::new(None) };
}

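/// Install immediate in-situ quantization (ISQ) parameters for the current
/// engine thread. Layers constructed afterwards whose `{prefix}.weight` path
/// matches one of `predicates` are quantized to `isq` as they are loaded.
///
/// A minimal usage sketch (the regex is an illustrative example, not a
/// required pattern, and assumes the crate is consumed as `mistralrs_quant`):
///
/// ```ignore
/// use regex::Regex;
/// use mistralrs_quant::{set_immediate_isq, IsqType};
///
/// // Quantize every matching projection weight to Q4K as it is loaded.
/// set_immediate_isq(
///     Some(IsqType::Q4K),
///     vec![Regex::new(r"(q|k|v|o)_proj\.weight$").unwrap()],
/// );
/// ```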
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    // Overrides take precedence over the global predicates: the first matching
    // override decides both the quantization type and the target device.
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}

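/// Pre-quantization configuration, typically parsed from a model's
/// `quantization_config` JSON. Deserialization dispatches on the
/// `quant_method` field; when it is absent, AFQ is assumed (see the manual
/// `Deserialize` impl below).
///
/// A minimal deserialization sketch (field values are illustrative, and
/// `serde_json` is assumed to be available to the caller):
///
/// ```ignore
/// let config: QuantizedConfig = serde_json::from_str(
///     r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#,
/// )?;
/// assert_eq!(config.name(), "gptq");
/// ```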
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

// Superset of the fields of all quantization configs; used to dispatch on
// `quant_method` during deserialization.
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            // No `quant_method` given: assume AFQ, which still requires
            // `bits` and `group_size`.
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
            ))),
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                40 => 4,
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

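/// Matrix-multiplication helper that dispatches on device and build features:
/// with the `accelerate` feature the operands are computed in f32, on plain
/// CPU builds they are cast to f16, and otherwise the native dtype is used.
/// The result is always cast back to the input dtype.
///
/// A minimal usage sketch (tensor shapes are illustrative):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let a = Tensor::randn(0f32, 1f32, (4, 8), &Device::Cpu)?;
/// let b = Tensor::randn(0f32, 1f32, (8, 2), &Device::Cpu)?;
/// let c = MatMul.matmul(&a, &b)?; // (4, 2), same dtype as `a`
/// ```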
pub struct MatMul;

impl MatMul {
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            // Accelerate provides fast f32 GEMM, so compute in f32 and cast back.
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                // On CPU, compute in f16 and cast back to the original dtype.
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute the matmul and divide the result by `scale`.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    /// Compute the matmul and multiply the result by `scale`.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    /// Compute a quantized matmul via `QMatMul`.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute a quantized matmul via a `QuantMethod` implementation.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

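/// Convolution helper mirroring `MatMul`: on CPU the layer and input are cast
/// to f32 before the convolution and the output is cast back, since
/// half-precision convolutions are slow or unsupported there; on other
/// devices the layer is applied directly.
///
/// A usage sketch (the weight tensor and config are illustrative):
///
/// ```ignore
/// use candle_nn::{Conv2d, Conv2dConfig};
///
/// let conv = Conv2d::new(weight, None, Conv2dConfig::default());
/// let y = Convolution.forward_2d(&conv, &x)?;
/// ```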
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            // Run the convolution in f32 on CPU, then cast back.
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            // Run the convolution in f32 on CPU, then cast back.
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
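    /// Roughly how many quantized elements fit in the space of one
    /// unquantized element of `dtype`: for the GGUF types this is
    /// `(dtype_size * block_size).div_ceil(type_size)`, while for HQQ and FP8
    /// the values are fixed estimates.
    ///
    /// Worked example: a `Q4K` block holds 256 elements in 144 bytes, so for
    /// an f16 (2-byte) source dtype the factor is `(2 * 256).div_ceil(144) = 4`.
    ///
    /// ```ignore
    /// assert_eq!(IsqType::Q4K.pack_factor(DType::F16), 4);
    /// ```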
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            // Fixed estimates for the non-GGUF types.
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => {
                // HQQ and AFQ quantization runs on the device itself, so cap
                // ISQ to a single CPU thread.
                Some(1.try_into().unwrap())
            }
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

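/// Conversion to the GGML dtype used for GGUF quantization. Only the GGUF
/// `Q*` variants convert; HQQ, FP8, and AFQ types bail. On CUDA builds the
/// set is further restricted to the dtypes with quantization kernels.
///
/// A round-trip sketch:
///
/// ```ignore
/// let ggml = GgmlDType::try_from(IsqType::Q4K)?;
/// assert_eq!(IsqType::try_from(ggml)?, IsqType::Q4K);
/// ```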
impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

/// Discriminant used to tag the serialization format of a quantized tensor.
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

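/// Serialization interface for quantized layers. Every method has a bailing
/// default, so implementors opt in by overriding `isq_serde_supported`
/// together with `serialize`/`deserialize`.
///
/// A sketch of the intended round trip (error handling elided; `GgufMatMul`
/// is one implementor, used here illustratively):
///
/// ```ignore
/// if layer.isq_serde_supported() {
///     let bytes = layer.serialize()?;
///     let restored = GgufMatMul::deserialize(bytes, &device, &comm, guard)?;
/// }
/// ```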
pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

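/// Guard protecting the critical section of quantizing a tensor onto a
/// device: concurrent quantization is serialized through a shared mutex,
/// except on CUDA builds where a no-op guard is returned.
///
/// A usage sketch:
///
/// ```ignore
/// let guard = QuantizeOntoGuard::new();
/// {
///     let _lock = guard.acquire(&device);
///     // ... quantize one tensor while holding the lock ...
/// } // lock released here
/// ```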
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Drop guard returned by `QuantizeOntoGuard::acquire`: a real mutex guard on
/// non-CUDA builds, and a no-op placeholder on CUDA.
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            // No serialization is needed on CUDA; hand back a no-op guard.
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            // On Metal, flush the command buffer first so we never encode onto
            // a buffer that is still in use.
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

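/// The core abstraction over quantized (and unquantized) linear layers: a
/// `QuantMethod` owns its weights, exposes matmul-style `forward` methods,
/// and can be re-quantized via `apply_isq`.
///
/// A sketch of calling a layer with automatic activation casting (assumes a
/// constructed `layer: Arc<dyn QuantMethod>` and an input tensor `x`):
///
/// ```ignore
/// // Casts `x` to the layer's quantized activation dtype (if any), runs the
/// // matmul, and casts the result back to x's original dtype.
/// let y = layer.forward_autocast(&x)?;
/// ```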
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute the matmul of `self` (the weights) and `a`, casting `a` to the
    /// quantized activation type (if any) and casting the result back.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute the matmul of `self` (the weights) and `a`.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Gathered matmul with the same automatic activation casting as
    /// `forward_autocast`.
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Gathered matmul, selecting weight matrices according to `indices`.
    /// Not all implementations support this.
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If the layer runs quantized, the dtype activations should be cast to.
    fn quantized_act_type(&self) -> Option<DType>;

    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight to the underlying weights, returning a new layer.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Quantize this layer to `dtype` on `device` (in-situ quantization),
    /// returning the quantized layer.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

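/// Construct a linear layer without bias, dispatching on the optional
/// pre-quantization `config`; with no config, the plain `weight` tensor is
/// loaded (or a `DummyLayer` if it is absent). Immediate ISQ is applied last.
///
/// A construction sketch (dimensions and the var-builder path are
/// illustrative):
///
/// ```ignore
/// let proj = linear_no_bias(hidden_size, vocab_size, &None, vb.pp("lm_head"))?;
/// let logits = proj.forward_autocast(&hidden_states)?;
/// ```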
pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    // Load onto the CPU first when immediate ISQ will quantize this layer anyway.
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors).
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    // Load onto the CPU first when immediate ISQ will quantize this layer anyway.
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors).
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    apply_immediate_isq(layer, base_vb)
}

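/// Construct a linear layer, with or without bias, by delegating to `linear`
/// or `linear_no_bias`.
///
/// A dispatch sketch:
///
/// ```ignore
/// // Equivalent to linear(in_dim, out_dim, &config, vb) when `bias` is true.
/// let layer = linear_b(in_dim, out_dim, /* bias = */ true, &config, vb)?;
/// ```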
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}