1use std::sync::Arc;
2
3use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
4use candle_nn::Linear;
5
6use crate::{
7 blockwise_fp8::blockwise_fp8_linear_b,
8 distributed,
9 gptq::gptq_linear,
10 lora::merge_lora_weights,
11 should_apply_immediate_isq,
12 utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
13 AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
14 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
15 QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
16};
17
18use super::{Comm, SumAllReduce};
19
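/// Build a `Shard::Simple` that splits dimension `dim` into `world_size` equal
/// parts and selects the part owned by `rank`.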
20fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
21 Shard::Simple {
22 dim,
23 rank,
24 world_size,
25 }
26}
27
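/// A row-parallel linear layer for tensor parallelism.
///
/// The weight is sharded along the input dimension (dim 1), so each rank holds a
/// slice of the columns and produces a partial matmul result; `forward` sums the
/// partial results with an all-reduce and then adds the (unsharded) bias.
///
/// A rough usage sketch (illustrative names, not a doctest):
///
/// ```ignore
/// let layer = RowParallelLayer::new(in_dim, out_dim, &None, /*bias=*/ false, &comm, vb)?;
/// let y = layer.forward(&x)?; // partial matmul + all-reduce across ranks
/// ```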
28#[derive(Debug)]
31pub struct RowParallelLayer {
32 weight: Arc<dyn QuantMethod>,
33 bias: Option<Tensor>,
34 all_reduce: distributed::SumAllReduce,
35}
36
37impl RowParallelLayer {
38 #[allow(clippy::new_ret_no_self)]
39 pub fn new(
40 in_dim: usize,
41 out_dim: usize,
42 config: &Option<QuantizedConfig>,
43 bias: bool,
44 comm: &Arc<crate::Comm>,
45 vb: ShardedVarBuilder,
46 ) -> Result<Arc<dyn QuantMethod>> {
47 let rank = comm.rank();
48 let world_size = comm.world_size();
49 let shard = shard(1, rank, world_size);
50
51 let base_vb = vb.clone();
52 let vb = if should_apply_immediate_isq(&vb) {
53 vb.set_device(Device::Cpu)
54 } else {
55 vb
56 };
57
58 let weight = if let Some(quant_conf) = &config {
59 if matches!(
61 quant_conf,
62 QuantizedConfig::GptqAwq { .. }
63 | QuantizedConfig::Bitsandbytes { .. }
64 | QuantizedConfig::Afq { .. }
65 ) && comm.world_size() != 1
66 {
67 candle_core::bail!(
68 "GPTQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
69 comm.world_size()
70 );
71 }
72
73 match quant_conf {
74 QuantizedConfig::GptqAwq { .. } => {
75 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
76 }
77 QuantizedConfig::Fp8 { .. } => {
78 blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
80 }
81 QuantizedConfig::Bitsandbytes { .. } => {
82 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
83 }
84 QuantizedConfig::Afq { .. } => {
85 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
86 }
87 QuantizedConfig::MXFP4 {} => {
88 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
89 }
90 }
91 } else {
92 if !vb.contains_tensor("weight") {
94 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
95 Arc::new(layer) as Arc<dyn QuantMethod>
96 } else {
97 let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
98 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
99
100 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
101 Linear::new(weight, None),
102 ))?;
103 Arc::new(layer) as Arc<dyn QuantMethod>
104 }
105 };
106
107 let bias = if bias && vb.contains_tensor("bias") {
109 Some(vb.get((out_dim,), "bias")?)
110 } else {
111 None
112 };
113
114 let this_unquant = Arc::new(Self {
115 weight,
116 bias,
117 all_reduce: distributed::SumAllReduce::new(comm),
118 });
119 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
120 Ok(this)
121 }
122
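    /// Like [`RowParallelLayer::new`], but for MatFormer-style checkpoints: the stored
    /// weight has `orig_intermediate_size` input columns, of which only the first
    /// `in_dim` are sliced out before the row shard is applied. Pre-quantized configs
    /// are rejected here.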
123 #[allow(clippy::new_ret_no_self)]
124 pub fn new_matformer(
125 in_dim: usize,
126 out_dim: usize,
127 orig_intermediate_size: usize,
128 config: &Option<QuantizedConfig>,
129 bias: bool,
130 comm: &Arc<crate::Comm>,
131 vb: ShardedVarBuilder,
132 ) -> Result<Arc<dyn QuantMethod>> {
133 let rank = comm.rank();
134 let world_size = comm.world_size();
135 let shard = shard(1, rank, world_size);
136
137 let base_vb = vb.clone();
138 let vb = if should_apply_immediate_isq(&vb) {
139 vb.set_device(Device::Cpu)
140 } else {
141 vb
142 };
143
144 if config.is_some() {
145 candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
146 }
147
148 let weight = if !vb.contains_tensor("weight") {
150 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
151 Arc::new(layer) as Arc<dyn QuantMethod>
152 } else {
153 let weight = vb
154 .get_with_hints(
155 (out_dim, orig_intermediate_size),
156 "weight",
157 Default::default(),
158 )?
159 .i((.., ..in_dim))?
160 .contiguous()?;
161
162 let weight = shard.apply_to(&weight)?;
163 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
164
165 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
166 Linear::new(weight, None),
167 ))?;
168 Arc::new(layer) as Arc<dyn QuantMethod>
169 };
170
171 let bias = if bias && vb.contains_tensor("bias") {
173 Some(vb.get((out_dim,), "bias")?)
174 } else {
175 None
176 };
177
178 let this_unquant = Arc::new(Self {
179 weight,
180 bias,
181 all_reduce: distributed::SumAllReduce::new(comm),
182 });
183 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
184 Ok(this)
185 }
186}
187
188impl QuantMethod for RowParallelLayer {
189 fn new(_method: QuantMethodConfig) -> Result<Self>
190 where
191 Self: Sized,
192 {
193 candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
194 }
195
196 fn forward(&self, a: &Tensor) -> Result<Tensor> {
197 let mut xs = self.weight.forward(a)?;
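        // The weight is sharded on the input dim, so each rank holds only a partial
        // sum of the output; combine across ranks before adding the bias.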
198 xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
199 if let Some(bias) = &self.bias {
200 xs = xs.broadcast_add(bias)?;
201 }
202 Ok(xs)
203 }
204
205 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
206 let weight = self.weight.add_delta_w(delta)?;
207 Ok(Arc::new(Self {
208 weight,
209 bias: self.bias.clone(),
210 all_reduce: self.all_reduce.clone(),
211 }))
212 }
213
214 fn dequantize_w(&self) -> Result<Tensor> {
215 self.weight.dequantize_w()
216 }
217
218 fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
219 self.weight.dtype_and_device()
220 }
221
222 fn begin_track_stats(&mut self) -> Result<()> {
223 Arc::get_mut(&mut self.weight)
224 .context("Failed to get &mut to weight")?
225 .begin_track_stats()
226 }
227
228 fn end_track_stats(&self) -> Result<Tensor> {
229 self.weight.end_track_stats()
230 }
231
232 fn quantized_act_type(&self) -> Option<candle_core::DType> {
233 self.weight.quantized_act_type()
234 }
235
236 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
237 self.weight.unquant_weight_bias()
238 }
239
240 fn apply_isq(
241 self: Arc<Self>,
242 dtype: Option<crate::IsqType>,
243 device: candle_core::Device,
244 n_quantized: &std::sync::atomic::AtomicUsize,
245 imatrix_weight: Option<Vec<f32>>,
246 guard: QuantizeOntoGuard,
247 ) -> Result<Arc<dyn QuantMethod>> {
248 let weight =
249 self.weight
250 .clone()
251 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
252 let bias = match &self.bias {
253 Some(b) => {
254 let (dtype, device) = weight.dtype_and_device();
255 Some(b.to_device(&device)?.to_dtype(dtype)?)
256 }
257 None => None,
258 };
259 Ok(Arc::new(Self {
260 weight,
261 bias,
262 all_reduce: self.all_reduce.clone(),
263 }))
264 }
265
266 fn is_distributed(&self) -> Option<DistributedKind> {
267 Some(DistributedKind::RowParallel)
268 }
269}
270
271impl QuantizedSerde for RowParallelLayer {
272 fn isq_serde_supported(&self) -> bool {
273 self.weight.isq_serde_supported()
274 }
275 fn name(&self) -> &'static str {
276 self.weight.name()
277 }
278 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
279 self.weight.serialize_with_bias(self.bias.clone())
280 }
281 fn deserialize(
282 data: std::borrow::Cow<[u8]>,
283 device: &candle_core::Device,
284 comm: &Arc<crate::Comm>,
285 guard: QuantizeOntoGuard,
286 ) -> Result<Arc<dyn QuantMethod>>
287 where
288 Self: Sized,
289 {
290 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
292 let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
293 QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
294 QuantizedSerdeType::Unquant => {
295 UnquantLinear::deserialize_ext_bias(data, device, guard)?
296 }
297 QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
298 QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
299 QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
300 };
301 Ok(Arc::new(Self {
302 weight,
303 bias,
304 all_reduce: SumAllReduce::new(comm),
305 }))
306 }
307}
308
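/// A column-parallel linear layer for tensor parallelism.
///
/// The weight and bias are sharded along the output dimension (dim 0), so each
/// rank produces a disjoint slice of the output features and `forward` needs no
/// collective communication; a following row-parallel layer typically performs
/// the all-reduce.
///
/// A rough usage sketch (illustrative names, not a doctest):
///
/// ```ignore
/// let layer = ColumnParallelLayer::new(in_dim, out_dim, &None, /*bias=*/ true, &comm, vb)?;
/// let y = layer.forward(&x)?; // each rank holds out_dim / world_size features
/// ```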
309#[derive(Debug)]
310pub struct ColumnParallelLayer {
313 weight: Arc<dyn QuantMethod>,
314 bias: Option<Tensor>,
315}
316
317impl ColumnParallelLayer {
318 #[allow(clippy::new_ret_no_self)]
319 pub fn new_with_shard(
320 in_dim: usize,
321 out_dim: usize,
322 config: &Option<QuantizedConfig>,
323 bias: bool,
324 comm: &Arc<crate::Comm>,
325 shard: Shard,
326 vb: ShardedVarBuilder,
327 ) -> Result<Arc<dyn QuantMethod>> {
328 let base_vb = vb.clone();
329 let vb = if should_apply_immediate_isq(&vb) {
330 vb.set_device(Device::Cpu)
331 } else {
332 vb
333 };
334
335 let weight = if let Some(quant_conf) = &config {
336 if matches!(
338 quant_conf,
339 QuantizedConfig::GptqAwq { .. }
340 | QuantizedConfig::Bitsandbytes { .. }
341 | QuantizedConfig::Afq { .. }
342 ) && comm.world_size() != 1
343 {
344 candle_core::bail!(
345 "GPTQ/AWQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
346 comm.world_size()
347 );
348 }
349
350 match quant_conf {
351 QuantizedConfig::GptqAwq { .. } => {
352 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
353 }
354 QuantizedConfig::Fp8 { .. } => {
355 blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
357 }
358 QuantizedConfig::Bitsandbytes { .. } => {
359 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
360 }
361 QuantizedConfig::Afq { .. } => {
362 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
363 }
364 QuantizedConfig::MXFP4 {} => {
365 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
366 }
367 }
368 } else {
369 if !vb.contains_tensor("weight") {
371 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
372 Arc::new(layer) as Arc<dyn QuantMethod>
373 } else {
374 let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
375 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
376
377 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
378 Linear::new(weight, None),
379 ))?;
380 Arc::new(layer) as Arc<dyn QuantMethod>
381 }
382 };
383
384 let bias = if bias && vb.contains_tensor("bias") {
386 Some(vb.get_with_hints((out_dim,), "bias", shard)?)
387 } else {
388 None
389 };
390
391 let this_unquant = Arc::new(Self { weight, bias });
392 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
393 Ok(this)
394 }
395
396 #[allow(clippy::new_ret_no_self)]
397 pub fn new(
398 in_dim: usize,
399 out_dim: usize,
400 config: &Option<QuantizedConfig>,
401 bias: bool,
402 comm: &Arc<crate::Comm>,
403 vb: ShardedVarBuilder,
404 ) -> Result<Arc<dyn QuantMethod>> {
405 let rank = comm.rank();
406 let world_size = comm.world_size();
407 let shard = shard(0, rank, world_size);
408
409 Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
410 }
411
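    /// Like [`ColumnParallelLayer::new`], but for MatFormer-style checkpoints: the
    /// stored weight has `orig_intermediate_size` output rows, of which only the first
    /// `out_dim` are sliced out before the column shard is applied. Pre-quantized
    /// configs are rejected here.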
412 #[allow(clippy::new_ret_no_self)]
413 pub fn new_matformer(
414 in_dim: usize,
415 out_dim: usize,
416 orig_intermediate_size: usize,
417 config: &Option<QuantizedConfig>,
418 bias: bool,
419 comm: &Arc<crate::Comm>,
420 vb: ShardedVarBuilder,
421 ) -> Result<Arc<dyn QuantMethod>> {
422 let rank = comm.rank();
423 let world_size = comm.world_size();
424 let shard = shard(0, rank, world_size);
425
426 let base_vb = vb.clone();
427 let vb = if should_apply_immediate_isq(&vb) {
428 vb.set_device(Device::Cpu)
429 } else {
430 vb
431 };
432
433 if config.is_some() {
434 candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
435 }
436
437 let weight = if !vb.contains_tensor("weight") {
439 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
440 Arc::new(layer) as Arc<dyn QuantMethod>
441 } else {
442 let weight = vb
443 .get_with_hints(
444 (orig_intermediate_size, in_dim),
445 "weight",
446 Default::default(),
447 )?
448 .i((..out_dim, ..))?
449 .contiguous()?;
450
451 let weight = shard.apply_to(&weight)?;
452 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
453
454 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
455 Linear::new(weight, None),
456 ))?;
457 Arc::new(layer) as Arc<dyn QuantMethod>
458 };
459
460 let bias = if bias && vb.contains_tensor("bias") {
462 Some(vb.get_with_hints((out_dim,), "bias", shard)?)
463 } else {
464 None
465 };
466
467 let this_unquant = Arc::new(Self { weight, bias });
468 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
469 Ok(this)
470 }
471
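    /// Load `chunks` column-parallel layers from a single merged projection (for
    /// example a fused gate/up weight). Chunk `i` on rank `r` takes slice index
    /// `i * world_size + r` of dim 0, so every chunk is itself sharded across ranks.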
472 pub fn new_merged(
473 in_dim: usize,
474 out_dim: usize,
475 chunks: usize,
476 config: &Option<QuantizedConfig>,
477 bias: bool,
478 comm: &Arc<crate::Comm>,
479 vb: ShardedVarBuilder,
480 ) -> Result<Vec<Arc<dyn QuantMethod>>> {
481 let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
482 for chunk_idx in 0..chunks {
483 let layer = ColumnParallelLayer::new_with_shard(
484 in_dim,
485 out_dim,
486 config,
487 bias,
488 comm,
489 shard(
490 0,
491 chunk_idx * comm.world_size() + comm.rank(),
492 chunks * comm.world_size(),
493 ),
494 vb.clone(),
495 )?;
496 vec_layers.push(layer);
497 }
498 Ok(vec_layers)
499 }
500}
501
502impl QuantMethod for ColumnParallelLayer {
503 fn new(_method: QuantMethodConfig) -> Result<Self>
504 where
505 Self: Sized,
506 {
507 candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
508 }
509
510 fn forward(&self, a: &Tensor) -> Result<Tensor> {
511 let mut xs = self.weight.forward(a)?;
512 if let Some(bias) = &self.bias {
513 xs = xs.broadcast_add(bias)?;
514 }
515 Ok(xs)
516 }
517
518 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
519 let weight = self.weight.add_delta_w(delta)?;
520 Ok(Arc::new(Self {
521 weight,
522 bias: self.bias.clone(),
523 }))
524 }
525
526 fn dequantize_w(&self) -> Result<Tensor> {
527 self.weight.dequantize_w()
528 }
529
530 fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
531 self.weight.dtype_and_device()
532 }
533
534 fn begin_track_stats(&mut self) -> Result<()> {
535 Arc::get_mut(&mut self.weight)
536 .context("Failed to get &mut to weight")?
537 .begin_track_stats()
538 }
539
540 fn end_track_stats(&self) -> Result<Tensor> {
541 self.weight.end_track_stats()
542 }
543
544 fn quantized_act_type(&self) -> Option<candle_core::DType> {
545 self.weight.quantized_act_type()
546 }
547
548 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
549 self.weight.unquant_weight_bias()
550 }
551
552 fn apply_isq(
553 self: Arc<Self>,
554 dtype: Option<crate::IsqType>,
555 device: candle_core::Device,
556 n_quantized: &std::sync::atomic::AtomicUsize,
557 imatrix_weight: Option<Vec<f32>>,
558 guard: QuantizeOntoGuard,
559 ) -> Result<Arc<dyn QuantMethod>> {
560 let weight =
561 self.weight
562 .clone()
563 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
564 let bias = match &self.bias {
565 Some(b) => {
566 let (dtype, device) = weight.dtype_and_device();
567 Some(b.to_device(&device)?.to_dtype(dtype)?)
568 }
569 None => None,
570 };
571 Ok(Arc::new(Self { weight, bias }))
572 }
573
574 fn is_distributed(&self) -> Option<DistributedKind> {
575 Some(DistributedKind::ColumnParallel)
576 }
577}
578
579impl QuantizedSerde for ColumnParallelLayer {
580 fn isq_serde_supported(&self) -> bool {
581 self.weight.isq_serde_supported()
582 }
583 fn name(&self) -> &'static str {
584 self.weight.name()
585 }
586 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
587 self.weight.serialize_with_bias(self.bias.clone())
588 }
589 fn deserialize(
590 data: std::borrow::Cow<[u8]>,
591 device: &candle_core::Device,
592 _comm: &Arc<crate::Comm>,
593 guard: QuantizeOntoGuard,
594 ) -> Result<Arc<dyn QuantMethod>>
595 where
596 Self: Sized,
597 {
598 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
600 let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
601 QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
602 QuantizedSerdeType::Unquant => {
603 UnquantLinear::deserialize_ext_bias(data, device, guard)?
604 }
605 QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
606 QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
607 QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
608 };
609 Ok(Arc::new(Self { weight, bias }))
610 }
611}
612
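/// A linear layer that is fully replicated on every rank: no sharding and no
/// collective communication, so `forward` is just the underlying matmul.
///
/// A rough usage sketch (illustrative names, not a doctest):
///
/// ```ignore
/// let layer = ReplicatedLayer::new(in_dim, out_dim, &None, /*bias=*/ true, vb)?;
/// let y = layer.forward(&x)?;
/// ```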
613#[derive(Debug)]
614pub struct ReplicatedLayer(Arc<dyn QuantMethod>);
616
617impl ReplicatedLayer {
618 pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
619 let dev = lin.weight().device().clone();
620 let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
621 let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
622 Ok(this)
623 }
624
625 #[allow(clippy::new_ret_no_self)]
626 pub fn new(
627 in_dim: usize,
628 out_dim: usize,
629 config: &Option<QuantizedConfig>,
630 bias: bool,
631 vb: ShardedVarBuilder,
632 ) -> Result<Arc<dyn QuantMethod>> {
633 let base_vb = vb.clone();
634 let vb = if should_apply_immediate_isq(&vb) {
635 vb.set_device(Device::Cpu)
636 } else {
637 vb
638 };
639
640 let layer = if let Some(quant_conf) = &config {
641 match quant_conf {
642 QuantizedConfig::GptqAwq { .. } => {
643 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
644 }
645 QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
646 in_dim,
647 out_dim,
648 quant_conf,
649 bias,
650 Default::default(),
651 vb.clone(),
652 )?,
653 QuantizedConfig::Bitsandbytes { .. } => {
654 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
655 }
656 QuantizedConfig::Afq { .. } => {
657 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
658 }
659 QuantizedConfig::MXFP4 {} => {
660 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
661 }
662 }
663 } else {
664 if !vb.contains_tensor("weight") {
666 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
667 Arc::new(layer) as Arc<dyn QuantMethod>
668 } else {
669 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
670 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
671
672 let bias = if bias {
673 Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
674 } else {
675 None
676 };
677 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
678 Linear::new(weight, bias),
679 ))?;
680 Arc::new(layer) as Arc<dyn QuantMethod>
681 }
682 };
683
684 let this_unquant = Arc::new(Self(layer));
685 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
686 Ok(this)
687 }
688
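    /// Like [`ReplicatedLayer::new`], but for MatFormer checkpoints whose weight stacks
    /// one block per original hidden layer: when `kept_layers_indices` is provided, the
    /// weight is reshaped to `(orig_num_hidden_layers, rows_per_layer, in_dim)`, the
    /// kept layers are gathered with `index_select`, and the result is flattened back
    /// to 2-D. Pre-quantized configs are rejected in that case.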
689 #[allow(clippy::new_ret_no_self)]
690 pub fn new_layers_matformer_indices(
691 in_dim: usize,
692 out_dim: usize,
693 kept_layers_indices: Option<&Tensor>,
694 orig_num_hidden_layers: usize,
695 config: &Option<QuantizedConfig>,
696 bias: bool,
697 vb: ShardedVarBuilder,
698 ) -> Result<Arc<dyn QuantMethod>> {
699 let base_vb = vb.clone();
700 let vb = if should_apply_immediate_isq(&vb) {
701 vb.set_device(Device::Cpu)
702 } else {
703 vb
704 };
705
706 let layer = if let Some(quant_conf) = &config {
707 if kept_layers_indices.is_some() {
708 candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
709 }
710
711 match quant_conf {
712 QuantizedConfig::GptqAwq { .. } => {
713 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
714 }
715 QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
716 in_dim,
717 out_dim,
718 quant_conf,
719 bias,
720 Default::default(),
721 vb.clone(),
722 )?,
723 QuantizedConfig::Bitsandbytes { .. } => {
724 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
725 }
726 QuantizedConfig::Afq { .. } => {
727 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
728 }
729 QuantizedConfig::MXFP4 {} => {
730 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
731 }
732 }
733 } else {
734 if !vb.contains_tensor("weight") {
736 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
737 Arc::new(layer) as Arc<dyn QuantMethod>
738 } else {
739 let mut weight =
740 vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
741
742 if let Some(kept_layers_indices) = &kept_layers_indices {
743 let weight_reshaped = weight.reshape((
744 orig_num_hidden_layers,
745 weight.dim(0)? / orig_num_hidden_layers,
746 weight.dim(1)?,
747 ))?;
748
749 weight = weight_reshaped
750 .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
751 .reshape(((), weight_reshaped.dim(D::Minus1)?))?
752 .contiguous()?;
753 }
754
755 weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
756
757 let bias = if bias {
758 Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
759 } else {
760 None
761 };
762 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
763 Linear::new(weight, bias),
764 ))?;
765 Arc::new(layer) as Arc<dyn QuantMethod>
766 }
767 };
768
769 let this_unquant = Arc::new(Self(layer));
770 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
771 Ok(this)
772 }
773}
774
775impl QuantMethod for ReplicatedLayer {
776 fn new(_method: QuantMethodConfig) -> Result<Self>
777 where
778 Self: Sized,
779 {
780 candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
781 }
782
783 fn forward(&self, a: &Tensor) -> Result<Tensor> {
784 self.0.forward(a)
785 }
786
787 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
788 self.0.add_delta_w(delta)
789 }
790
791 fn dequantize_w(&self) -> Result<Tensor> {
792 self.0.dequantize_w()
793 }
794
795 fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
796 self.0.dtype_and_device()
797 }
798
799 fn begin_track_stats(&mut self) -> Result<()> {
800 Arc::get_mut(&mut self.0)
801 .context("Failed to get &mut to weight")?
802 .begin_track_stats()
803 }
804
805 fn end_track_stats(&self) -> Result<Tensor> {
806 self.0.end_track_stats()
807 }
808
809 fn quantized_act_type(&self) -> Option<candle_core::DType> {
810 self.0.quantized_act_type()
811 }
812
813 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
814 self.0.unquant_weight_bias()
815 }
816
817 fn apply_isq(
818 self: Arc<Self>,
819 dtype: Option<crate::IsqType>,
820 device: candle_core::Device,
821 n_quantized: &std::sync::atomic::AtomicUsize,
822 imatrix_weight: Option<Vec<f32>>,
823 guard: QuantizeOntoGuard,
824 ) -> Result<Arc<dyn QuantMethod>> {
825 self.0
826 .clone()
827 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
828 }
829
830 fn is_distributed(&self) -> Option<DistributedKind> {
831 Some(DistributedKind::Replicated)
832 }
833}
834
835impl QuantizedSerde for ReplicatedLayer {
836 fn isq_serde_supported(&self) -> bool {
837 self.0.isq_serde_supported()
838 }
839 fn name(&self) -> &'static str {
840 self.0.name()
841 }
842 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
843 self.0.serialize()
844 }
845 fn deserialize(
846 data: std::borrow::Cow<[u8]>,
847 device: &candle_core::Device,
848 comm: &Arc<crate::Comm>,
849 guard: QuantizeOntoGuard,
850 ) -> Result<Arc<dyn QuantMethod>>
851 where
852 Self: Sized,
853 {
854 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
856 let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
857 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
858 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
859 QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
860 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
861 QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
862 };
863 Ok(Arc::new(Self(deserialized)))
864 }
865}
866
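/// Per-projection expert weights for a packed MoE MLP.
///
/// In the unquantized path each `Vec` holds one layer per local expert, loaded by
/// slicing the fused 3-D `gate_up_proj` / `down_proj` tensors and sharding the
/// intermediate dimension across ranks. In the AFQ path each `Vec` holds a single
/// packed layer covering all experts, and tensor parallelism is not supported.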
867#[derive(Debug)]
868pub struct PackedExperts {
869 pub gate_proj: Vec<Arc<dyn QuantMethod>>,
870 pub up_proj: Vec<Arc<dyn QuantMethod>>,
871 pub down_proj: Vec<Arc<dyn QuantMethod>>,
872}
873
874impl PackedExperts {
875 #[allow(clippy::too_many_arguments)]
877 pub fn new(
878 num_local_experts: usize,
879 hidden_size: usize,
880 intermediate_size: usize,
881 config: &Option<QuantizedConfig>,
882 bias: bool,
883 comm: &Arc<crate::Comm>,
884 vb: ShardedVarBuilder,
885 ) -> Result<Self> {
886 if bias {
887 candle_core::bail!("PackedExperts does not support bias.");
888 }
889
890 let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
891 if comm.world_size() != 1 {
893 candle_core::bail!(
894 "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
895 comm.world_size()
896 );
897 }
898
899 match quant_conf {
900 QuantizedConfig::Afq { .. } => {
                    if vb.contains_tensor("gate_up_proj")
                        || vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with an AFQ quantization config does not support the fused `gate_up_proj` format.");
905 }
906
907 let base_vb = vb.clone();
908
909 let vb_gate_proj = if should_apply_immediate_isq(&vb) {
910 vb.pp("gate_proj").set_device(Device::Cpu)
911 } else {
912 vb.pp("gate_proj")
913 };
914 let vb_up_proj = if should_apply_immediate_isq(&vb) {
915 vb.pp("up_proj").set_device(Device::Cpu)
916 } else {
917 vb.pp("up_proj")
918 };
919 let vb_down_proj = if should_apply_immediate_isq(&vb) {
920 vb.pp("down_proj").set_device(Device::Cpu)
921 } else {
922 vb.pp("down_proj")
923 };
924 let mut gate_proj = AfqLayer::afq_packed_linear_b(
925 num_local_experts,
926 hidden_size,
927 intermediate_size,
928 quant_conf,
929 bias,
930 vb_gate_proj,
931 )?;
932 let mut up_proj = AfqLayer::afq_packed_linear_b(
933 num_local_experts,
934 hidden_size,
935 intermediate_size,
936 quant_conf,
937 bias,
938 vb_up_proj,
939 )?;
940 let mut down_proj = AfqLayer::afq_packed_linear_b(
941 num_local_experts,
942 intermediate_size,
943 hidden_size,
944 quant_conf,
945 bias,
946 vb_down_proj,
947 )?;
948
949 gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
950 up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
951 down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;
952
953 (vec![gate_proj], vec![up_proj], vec![down_proj])
954 }
955 _ => candle_core::bail!(
956 "PackedExperts with quantization config only allows AFQ quantization"
957 ),
958 }
959 } else if !vb.contains_tensor("gate_up_proj") {
960 let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
962 let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
963 let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
964 for _ in 0..num_local_experts {
965 gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
966 us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
967 ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
968 }
969 (gs, us, ds)
970 } else {
971 let gate_up_block_size = intermediate_size / comm.world_size();
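            // `gate_up_proj` stores gate and up concatenated along the last dim:
            // columns [0, intermediate_size) are the gate projection and
            // [intermediate_size, 2 * intermediate_size) are the up projection.
            // Each rank takes its own block from both halves.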
979 let gate_up_start = gate_up_block_size * comm.rank();
980
981 let shard_gate = Shard::Offset {
983 dim: 2,
984 offset: gate_up_start,
985 len: gate_up_block_size,
986 };
987 let shard_up = Shard::Offset {
988 dim: 2,
989 offset: intermediate_size + gate_up_start,
990 len: gate_up_block_size,
991 };
992 let shard_down = Shard::Simple {
993 dim: 1,
994 rank: comm.rank(),
995 world_size: comm.world_size(),
996 };
997
998 let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
999 vb.pp("gate_up_proj").set_device(Device::Cpu)
1000 } else {
1001 vb.pp("gate_up_proj")
1002 };
1003 let vb_down_proj = if should_apply_immediate_isq(&vb) {
1004 vb.pp("down_proj").set_device(Device::Cpu)
1005 } else {
1006 vb.pp("down_proj")
1007 };
1008
1009 let gate_proj = vb
1010 .get_with_hints(
1011 (num_local_experts, hidden_size, intermediate_size * 2),
1012 "gate_up_proj",
1013 shard_gate,
1014 )?
1015 .t()?
1016 .contiguous()?;
1017 let up_proj = vb
1018 .get_with_hints(
1019 (num_local_experts, hidden_size, intermediate_size * 2),
1020 "gate_up_proj",
1021 shard_up,
1022 )?
1023 .t()?
1024 .contiguous()?;
1025 let down_proj = vb
1026 .get_with_hints(
1027 (num_local_experts, intermediate_size, hidden_size),
1028 "down_proj",
1029 shard_down,
1030 )?
1031 .t()?
1032 .contiguous()?;
1033
1034 let gc = gate_proj.chunk(num_local_experts, 0)?;
1035 let uc = up_proj.chunk(num_local_experts, 0)?;
1036 let dc = down_proj.chunk(num_local_experts, 0)?;
1037 drop((gate_proj, up_proj, down_proj));
1038
1039 let mut gs = Vec::new();
1040 let mut us = Vec::new();
1041 let mut ds = Vec::new();
1042 for ((mut gate_proj, mut up_proj), mut down_proj) in
1043 gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
1044 {
1045 gate_proj = gate_proj.squeeze(0)?;
1046 up_proj = up_proj.squeeze(0)?;
1047 down_proj = down_proj.squeeze(0)?;
1048 let gate_proj = merge_lora_weights(
1049 &vb,
1050 gate_proj,
1051 hidden_size,
1052 intermediate_size * 2,
1053 shard_gate,
1054 )?;
1055 let up_proj =
1056 merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1057 let down_proj =
1058 merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1059
1060 let mut gate_proj: Arc<dyn QuantMethod> =
1061 Arc::new(<UnquantLinear as QuantMethod>::new(
1062 QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1063 )?);
1064 gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1065 let mut up_proj: Arc<dyn QuantMethod> =
1066 Arc::new(<UnquantLinear as QuantMethod>::new(
1067 QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1068 )?);
1069 up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1070 let mut down_proj: Arc<dyn QuantMethod> =
1071 Arc::new(<UnquantLinear as QuantMethod>::new(
1072 QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1073 )?);
1074 down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1075 gs.push(gate_proj);
1076 us.push(up_proj);
1077 ds.push(down_proj);
1078 }
1079 (gs, us, ds)
1080 };
1081
1082 Ok(Self {
1083 gate_proj,
1084 up_proj,
1085 down_proj,
1086 })
1087 }
1088}
1089
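/// Gate/up/down projections for all experts fused into single stacked weights,
/// used by the Metal fast-MoE path (construction requires a Metal device).
///
/// AFQ checkpoints are loaded as packed `switch_mlp.*` projections; FP8 checkpoints
/// are dequantized per expert and stacked; unquantized checkpoints are stacked
/// directly into `(num_experts, out_dim, in_dim)` weights.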
1090pub struct FusedExperts {
1091 pub fused_gate_proj: Arc<dyn QuantMethod>,
1092 pub fused_up_proj: Arc<dyn QuantMethod>,
1093 pub fused_down_proj: Arc<dyn QuantMethod>,
1094}
1095
1096impl FusedExperts {
1097 pub fn new(
1098 hidden_size: usize,
1099 moe_intermediate_size: usize,
1100 num_experts: usize,
1101 quantization_config: &Option<QuantizedConfig>,
1102 vb: ShardedVarBuilder,
1103 ) -> Result<Self> {
1104 if !vb.device().is_metal() {
            candle_core::bail!("FusedExperts requires a Metal device.");
1106 }
1107
1108 let (fused_gate_proj, fused_up_proj, fused_down_proj) =
1109 if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
1110 let quantization_config = quantization_config.as_ref().unwrap();
1111
1112 let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1113 num_experts,
1114 hidden_size,
1115 moe_intermediate_size,
1116 quantization_config,
1117 false,
1118 vb.pp("switch_mlp.gate_proj"),
1119 )?;
1120 let fused_up_proj = AfqLayer::afq_packed_linear_b(
1121 num_experts,
1122 hidden_size,
1123 moe_intermediate_size,
1124 quantization_config,
1125 false,
1126 vb.pp("switch_mlp.up_proj"),
1127 )?;
1128 let fused_down_proj = AfqLayer::afq_packed_linear_b(
1129 num_experts,
1130 moe_intermediate_size,
1131 hidden_size,
1132 quantization_config,
1133 false,
1134 vb.pp("switch_mlp.down_proj"),
1135 )?;
1136
1137 (fused_gate_proj, fused_up_proj, fused_down_proj)
1138 } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1139 let experts_vb = vb.pp("experts");
1140 let mut gate_proj_vec = Vec::new();
1141 let mut up_proj_vec = Vec::new();
1142 let mut down_proj_vec = Vec::new();
1143 for i in 0..num_experts {
1144 let vb = experts_vb.pp(i);
1145
1146 let gate_proj = crate::linear_no_bias(
1147 hidden_size,
1148 moe_intermediate_size,
1149 quantization_config,
1150 vb.pp("gate_proj.weight"),
1151 )?;
1152 let up_proj = crate::linear_no_bias(
1153 hidden_size,
1154 moe_intermediate_size,
1155 quantization_config,
1156 vb.pp("up_proj.weight"),
1157 )?;
1158 let down_proj = crate::linear_no_bias(
1159 moe_intermediate_size,
1160 hidden_size,
1161 quantization_config,
1162 vb.pp("down_proj.weight"),
1163 )?;
1164
1165 gate_proj_vec.push(gate_proj.dequantize_w()?);
1166 up_proj_vec.push(up_proj.dequantize_w()?);
1167 down_proj_vec.push(down_proj.dequantize_w()?);
1168 }
1169
1170 let mut gate_proj: Arc<dyn QuantMethod> =
1171 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1172 Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1173 ))?);
1174 let mut up_proj: Arc<dyn QuantMethod> =
1175 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1176 Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
1177 ))?);
1178 let mut down_proj: Arc<dyn QuantMethod> =
1179 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1180 Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1181 ))?);
1182 gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
1183 up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
1184 down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;
1185
1186 (gate_proj, up_proj, down_proj)
1187 } else {
1188 let experts_vb = vb.pp("experts");
1189 let mut gate_proj_vec = Vec::new();
1190 let mut up_proj_vec = Vec::new();
1191 let mut down_proj_vec = Vec::new();
1192 for i in 0..num_experts {
1193 let vb = experts_vb.pp(i);
1194 let gate_proj =
1195 vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1196 let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1197 let down_proj =
1198 vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1199
1200 gate_proj_vec.push(gate_proj);
1201 up_proj_vec.push(up_proj);
1202 down_proj_vec.push(down_proj);
1203 }
1204
1205 let mut gate_proj: Arc<dyn QuantMethod> =
1206 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1207 Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1208 ))?);
1209 let mut up_proj: Arc<dyn QuantMethod> =
1210 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1211 Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
1212 ))?);
1213 let mut down_proj: Arc<dyn QuantMethod> =
1214 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1215 Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1216 ))?);
1217 gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
1218 up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
1219 down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;
1220
1221 (gate_proj, up_proj, down_proj)
1222 };
1223
1224 Ok(Self {
1225 fused_gate_proj,
1226 fused_up_proj,
1227 fused_down_proj,
1228 })
1229 }
1230}
1231
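/// Compute how the KV projection weight should be sharded across ranks.
///
/// With `world_size <= total_num_kv_heads`, each rank takes its contiguous slice of
/// KV heads along dim 0. With more ranks than KV heads, each head is replicated
/// across `world_size / total_num_kv_heads` ranks and every rank receives a single
/// head-sized slice at offset `(rank / kv_replicate) * head_dim`.
///
/// A small illustrative example: with `total_num_kv_heads = 8`, `head_dim = 128`,
/// and a world size of 4, rank 1 gets `Shard::Simple { dim: 0, rank: 1, world_size: 4 }`,
/// i.e. KV heads 2 and 3.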
1232pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
1234 if comm.world_size() == 1 {
1235 return Shard::default();
1236 }
1237
1238 let kv_replicate = if comm.world_size() > total_num_kv_heads {
1242 comm.world_size() / total_num_kv_heads
1243 } else {
1244 return Shard::Simple {
1245 dim: 0,
1246 rank: comm.rank(),
1247 world_size: comm.world_size(),
1248 };
1249 };
1250
1251 let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
1252 let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
1253 Shard::Offset {
1254 dim: 0,
1255 offset: kv_shard_id * head_dim,
1256 len: head_dim,
1257 }
1258}
1259
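/// Compute the GQA group size (query heads per KV head) on this rank, accounting
/// for KV-head replication when the world size exceeds the number of KV heads.
/// Illustrative example: 32 query heads, 8 KV heads, and a world size of 16 give
/// `kv_replicate = 2` and a group size of `(32 / 8) / 2 = 2`.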
1260pub fn compute_n_kv_groups(
1262 total_num_kv_heads: usize,
1263 num_attention_heads: usize,
1264 comm: &Comm,
1265) -> usize {
1266 let kv_replicate = if comm.world_size() > total_num_kv_heads {
1267 comm.world_size() / total_num_kv_heads
1268 } else {
1269 1
1270 };
1271 if kv_replicate != 0 {
1272 (num_attention_heads / total_num_kv_heads) / kv_replicate
1273 } else {
1274 num_attention_heads / total_num_kv_heads
1275 }
1276}