use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
pub mod rotary;
pub mod safetensors;
mod unquantized;
mod utils;

use gptq::gptq_linear;
use lora::merge_lora_weights;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParmas, BnbQuantType};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, ReplicatedLayer,
        RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    linear_no_bias_static_lora, LoraAdapter, LoraConfig, StaticLoraConfig, APPLIED_LORAS,
    MULTI_LORA_DELIMITER,
};
pub use unquantized::UnquantLinear;
pub use utils::UQFF_QUANT_TYPE_OFFSET;

use candle_nn::{Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

/// Quantization configuration describing how a model's weights are quantized
/// (GPTQ, FP8, bitsandbytes, or AFQ).
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    Gptq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
    },
    Fp8 {
        weight_block_size: Vec<usize>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
}

// Superset of the fields used by all quantization methods; `QuantizedConfig`
// deserialization dispatches on `quant_method` manually.
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

// Custom deserializer: dispatch on `quant_method`, falling back to AFQ when it is
// absent (only `bits` and `group_size` are required in that case).
impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Gptq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                })
            }
            Some(m) if m == "fp8" => {
                let weight_block_size = raw
                    .weight_block_size
                    .ok_or_else(|| serde::de::Error::missing_field("weight_block_size"))?;
                Ok(QuantizedConfig::Fp8 { weight_block_size })
            }
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            // No `quant_method` given: assume AFQ.
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
            ))),
        }
    }
}
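
#[cfg(test)]
mod quantized_config_tests {
    // A minimal sketch of how `QuantizedConfig` is expected to deserialize from an
    // HF-style `quantization_config` JSON object. Assumes `serde_json` is available
    // (it may need to be a dev-dependency); the JSON literals are illustrative only.
    use super::QuantizedConfig;

    #[test]
    fn parses_gptq_config() {
        let json = r#"{ "quant_method": "gptq", "bits": 4, "group_size": 128 }"#;
        let cfg: QuantizedConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.name(), "gptq");
        assert!(matches!(
            cfg,
            QuantizedConfig::Gptq {
                bits: 4,
                group_size: 128,
                checkpoint_format: None,
            }
        ));
    }

    #[test]
    fn rejects_unknown_method() {
        let json = r#"{ "quant_method": "does-not-exist" }"#;
        assert!(serde_json::from_str::<QuantizedConfig>(json).is_err());
    }
}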

impl QuantizedConfig {
    /// Short name of the quantization method (matches the `quant_method` tag).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Gptq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
        }
    }

    /// Human-readable bit width, e.g. "4 bits".
    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::Gptq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
        }
    }
}

/// Configuration used to construct a concrete [`QuantMethod`] layer.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    Gptq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        gptq_qzeros: Option<Tensor>,
        gptq_scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParmas,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
}
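
#[cfg(test)]
mod quant_method_config_tests {
    // A minimal sketch: building an unquantized layer through the same
    // `QuantMethodConfig::Unquantized` path that `linear_no_bias` uses below. The
    // shape and the expectation that `dequantize_w` returns the stored weight are
    // assumptions about `UnquantLinear`, not guarantees from this file.
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::Linear;

    use super::{QuantMethod, QuantMethodConfig, UnquantLinear};

    #[test]
    fn unquantized_construction() -> Result<()> {
        let w = Tensor::ones((4, 3), DType::F32, &Device::Cpu)?;
        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(w, None),
        ))?;
        // Dequantizing an unquantized layer should just hand back the stored weight.
        assert_eq!(layer.dequantize_w()?.dims(), &[4, 3]);
        Ok(())
    }
}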

/// Matrix-multiplication helper that optionally routes through f16/f32 to hit faster
/// GEMM kernels, casting back to the original dtype afterwards.
pub struct MatMul;

impl MatMul {
    /// Compute the matrix product `a @ b`, casting as needed for the backend.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute the matrix product, then divide the result by `scale`.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    /// Compute the matrix product, then multiply the result by `scale`.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    /// Compute a quantized matmul via [`QMatMul`].
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Compute a quantized matmul via a [`QuantMethod`] implementation.
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}
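
#[cfg(test)]
mod matmul_tests {
    // A minimal sketch of the `MatMul` helper: on CPU (without the `accelerate`
    // feature) inputs are routed through f16 and cast back, so the output dtype
    // should match the input dtype regardless of the internal cast.
    use candle_core::{DType, Device, Result, Tensor};

    use super::MatMul;

    #[test]
    fn matmul_preserves_dtype() -> Result<()> {
        let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
        let b = Tensor::ones((3, 4), DType::F32, &Device::Cpu)?;
        let c = MatMul.matmul(&a, &b)?;
        assert_eq!(c.dims(), &[2, 4]);
        assert_eq!(c.dtype(), DType::F32);
        Ok(())
    }
}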

/// Target type for in-situ quantization (ISQ).
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

impl IsqType {
    /// Compression factor relative to `dtype`: the number of bytes a block of elements
    /// occupies unquantized in `dtype`, divided by the size of the quantized block.
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => {
                (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size()) / GgmlDType::Q4_0.type_size()
            }
            Self::Q4_1 => {
                (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size()) / GgmlDType::Q4_1.type_size()
            }
            Self::Q5_0 => {
                (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size()) / GgmlDType::Q5_0.type_size()
            }
            Self::Q5_1 => {
                (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size()) / GgmlDType::Q5_1.type_size()
            }
            Self::Q8_0 | Self::AFQ8 => {
                (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size()) / GgmlDType::Q8_0.type_size()
            }
            Self::Q8_1 => {
                (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size()) / GgmlDType::Q8_1.type_size()
            }
            Self::Q2K | Self::AFQ2 => {
                (dtype.size_in_bytes() * GgmlDType::Q2K.block_size()) / GgmlDType::Q2K.type_size()
            }
            Self::Q3K | Self::AFQ3 => {
                (dtype.size_in_bytes() * GgmlDType::Q3K.block_size()) / GgmlDType::Q3K.type_size()
            }
            Self::Q4K => {
                (dtype.size_in_bytes() * GgmlDType::Q4K.block_size()) / GgmlDType::Q4K.type_size()
            }
            Self::Q5K => {
                (dtype.size_in_bytes() * GgmlDType::Q5K.block_size()) / GgmlDType::Q5K.type_size()
            }
            Self::Q6K | Self::AFQ6 => {
                (dtype.size_in_bytes() * GgmlDType::Q6K.block_size()) / GgmlDType::Q6K.type_size()
            }
            Self::Q8K => {
                (dtype.size_in_bytes() * GgmlDType::Q8K.block_size()) / GgmlDType::Q8K.type_size()
            }
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    /// Maximum number of CPU threads to use while quantizing with this type, if limited.
    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            // HQQ and AFQ quantization is restricted to a single ISQ thread.
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => Some(1.try_into().unwrap()),
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}
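
#[cfg(test)]
mod isq_type_tests {
    // Small sanity checks mirroring the match arms above: HQQ/AFQ quantization is
    // capped at a single ISQ CPU thread, while GGUF-style types have no cap.
    use super::IsqType;

    #[test]
    fn max_isq_cpu_threads() {
        assert_eq!(
            IsqType::HQQ4.get_max_isq_cpu_threads().map(|n| n.get()),
            Some(1)
        );
        assert_eq!(IsqType::Q4K.get_max_isq_cpu_threads(), None);
    }
}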

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}
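
#[cfg(test)]
mod isq_ggml_conversion_tests {
    // A minimal sketch of the conversions defined above: GGUF-representable ISQ
    // types round-trip through `GgmlDType`, while HQQ/AFQ/FP8 types do not.
    use candle_core::quantized::GgmlDType;

    use super::IsqType;

    #[test]
    fn round_trip_q4k() {
        let ggml = GgmlDType::try_from(IsqType::Q4K).unwrap();
        assert_eq!(IsqType::try_from(ggml).unwrap(), IsqType::Q4K);
    }

    #[test]
    fn hqq_has_no_ggml_equivalent() {
        assert!(GgmlDType::try_from(IsqType::HQQ8).is_err());
    }
}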

pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}
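
#[cfg(test)]
mod quantized_serde_type_tests {
    // The discriminants above act as serialization tags; this just checks the
    // usize round-trip and that unknown values are rejected.
    use super::QuantizedSerdeType;

    #[test]
    fn from_usize() {
        assert!(matches!(
            QuantizedSerdeType::try_from(2usize),
            Ok(QuantizedSerdeType::Hqq)
        ));
        assert!(QuantizedSerdeType::try_from(99usize).is_err());
    }
}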

/// Optional (de)serialization support for quantized layers.
pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<[u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

/// Guard used to serialize access to the quantize-onto (ISQ) critical section.
/// Clones share the same underlying mutex.
#[derive(Clone)]
#[allow(unused)]
pub struct QuantizeOntoGuard(Arc<Mutex<()>>);

/// RAII guard returned by [`QuantizeOntoGuard::acquire`]; the lock (if any) is
/// released on drop.
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(())))
    }

    /// Acquire the guard. On CUDA builds this is a no-op (`Fake`); otherwise the
    /// shared mutex is held for the lifetime of the returned guard.
    pub fn acquire(&self) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            QuantizeOntoDropGuard::Real(self.0.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
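
#[cfg(test)]
mod quantize_onto_guard_tests {
    // A minimal sketch of the guard: clones share one mutex, so (outside the `cuda`
    // feature, where acquiring is a no-op) only one holder is inside the critical
    // section at a time, and dropping the guard releases it.
    use super::{QuantizeOntoDropGuard, QuantizeOntoGuard};

    #[test]
    fn acquire_releases_on_drop() {
        let guard = QuantizeOntoGuard::new();
        {
            let _lock: QuantizeOntoDropGuard<'_> = guard.acquire();
            // quantize-onto work would happen here
        }
        // The first lock was dropped above, so acquiring again must not deadlock.
        let _lock = guard.acquire();
    }
}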

/// How a layer's weights are split across devices in distributed execution.
pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

/// Common interface for a (possibly quantized) linear layer.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    /// Dequantize the stored weight back into a floating-point tensor.
    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute the forward pass, casting the activation to `quantized_act_type()`
    /// (if any) beforehand and back to its original dtype afterwards.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute the forward pass for the activation `a`.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Gather variant of [`QuantMethod::forward_autocast`].
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Forward pass where `indices` selects which weights to use (e.g. gather-based
    /// expert kernels); not all layers support this.
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If not `None`, the activation is cast to this dtype before `forward`.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device.
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight (e.g. from LoRA) to the underlying weight, returning a new layer.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply in-situ quantization (ISQ) with the given target type, placing the
    /// result on `device`.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    /// If available, return the unquantized weight and bias.
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin collecting imatrix statistics for this layer, if supported.
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// Finish collecting imatrix statistics and return them, if supported.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}

/// Build a linear layer without bias, dispatching on the optional quantization config.
pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::Gptq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors).
        if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    Ok(layer)
}

/// Build a linear layer with bias, dispatching on the optional quantization config.
pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::Gptq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { .. } => {
                blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, true, Default::default(), vb)?
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else {
        // Handle the case where the layer is dummy (no tensors).
        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, Some(bias)),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        }
    };
    Ok(layer)
}

/// Build a linear layer, with or without bias, dispatching on the optional quantization config.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}
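
// A minimal sketch of how these constructors are typically driven from model code.
// `vb` is assumed to be a `ShardedVarBuilder` already scoped near the layer's prefix
// (for example obtained via `ShardedSafeTensors` when loading a checkpoint); the
// prefix, field names, and dimensions below are illustrative only:
//
//     let proj = linear_b(
//         hidden_size,           // in_dim
//         intermediate_size,     // out_dim
//         /* bias */ false,
//         &config.quantization_config,
//         vb.pp("gate_proj"),
//     )?;
//     let y = proj.forward_autocast(&x)?;
//
// With `quantization_config: None` and a present "weight" tensor this resolves to an
// `UnquantLinear`; with a GPTQ/FP8/bitsandbytes/AFQ config it dispatches to the
// matching quantized layer; with no tensors at all it falls back to `DummyLayer`.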