1use std::sync::Arc;
2
3use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
4use candle_nn::Linear;
5
6use crate::{
7 blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
8 distributed,
9 gptq::gptq_linear,
10 lora::merge_lora_weights,
11 pertensor_fp8::pertensor_fp8_linear_b,
12 should_apply_immediate_isq,
13 utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
14 AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
15 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
16 QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
17};
18
19use super::{Comm, SumAllReduce};
20
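/// Builds a `Shard::Simple` descriptor: split a tensor into `world_size` equal
/// slices along `dim` and take the slice belonging to `rank`.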
21fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
22 Shard::Simple {
23 dim,
24 rank,
25 world_size,
26 }
27}
28
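/// Linear layer whose weight is sharded along the input dimension (dim 1). Each rank
/// multiplies its slice of the input features by its slice of the weight; the partial
/// products are combined with a sum all-reduce in `forward`, and the bias (kept
/// unsharded) is added after the reduction.
///
/// A minimal usage sketch, assuming `cfg`, `comm`, `vb`, and the dimensions are already
/// set up (the variable names are illustrative only):
///
/// ```ignore
/// // The usual tensor-parallel MLP pattern: a column-parallel up-projection feeds a
/// // row-parallel down-projection, so the only communication is the single all-reduce
/// // inside `down.forward`.
/// let up = ColumnParallelLayer::new(hidden, intermediate, &cfg, false, &comm, vb.pp("up_proj"))?;
/// let down = RowParallelLayer::new(intermediate, hidden, &cfg, false, &comm, vb.pp("down_proj"))?;
/// let ys = down.forward(&up.forward(&xs)?)?;
/// ```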
29#[derive(Debug)]
32pub struct RowParallelLayer {
33 weight: Arc<dyn QuantMethod>,
34 bias: Option<Tensor>,
35 all_reduce: distributed::SumAllReduce,
36}
37
38impl RowParallelLayer {
39 #[allow(clippy::new_ret_no_self)]
40 pub fn new(
41 in_dim: usize,
42 out_dim: usize,
43 config: &Option<QuantizedConfig>,
44 bias: bool,
45 comm: &Arc<crate::Comm>,
46 vb: ShardedVarBuilder,
47 ) -> Result<Arc<dyn QuantMethod>> {
48 let rank = comm.rank();
49 let world_size = comm.world_size();
50 let shard = shard(1, rank, world_size);
51
52 let base_vb = vb.clone();
53 let vb = if should_apply_immediate_isq(&vb) {
54 vb.set_device(Device::Cpu)
55 } else {
56 vb
57 };
58
59 let weight = if let Some(quant_conf) = &config {
60 if matches!(
62 quant_conf,
63 QuantizedConfig::GptqAwq { .. }
64 | QuantizedConfig::Bitsandbytes { .. }
65 | QuantizedConfig::Afq { .. }
66 ) && comm.world_size() != 1
67 {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
72 }
73
74 match quant_conf {
75 QuantizedConfig::GptqAwq { .. } => {
76 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
77 }
78 QuantizedConfig::Fp8 { weight_block_size } => {
79 if weight_block_size.is_some() {
81 blockwise_fp8_linear_b(
82 in_dim,
83 out_dim,
84 quant_conf,
85 false,
86 shard,
87 vb.clone(),
88 )?
89 } else {
90 pertensor_fp8_linear_b(
91 in_dim,
92 out_dim,
93 quant_conf,
94 false,
95 shard,
96 vb.clone(),
97 )?
98 }
99 }
100 QuantizedConfig::Bitsandbytes { .. } => {
101 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
102 }
103 QuantizedConfig::Afq { .. } => {
104 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
105 }
106 QuantizedConfig::MXFP4 {} => {
107 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
108 }
109 }
110 } else {
111 if !vb.contains_tensor("weight") {
113 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
114 Arc::new(layer) as Arc<dyn QuantMethod>
115 } else {
116 let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
117 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
118
119 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
120 Linear::new(weight, None),
121 ))?;
122 Arc::new(layer) as Arc<dyn QuantMethod>
123 }
124 };
125
126 let bias = if bias && vb.contains_tensor("bias") {
128 Some(vb.get((out_dim,), "bias")?)
129 } else {
130 None
131 };
132
133 let this_unquant = Arc::new(Self {
134 weight,
135 bias,
136 all_reduce: distributed::SumAllReduce::new(comm),
137 });
138 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
139 Ok(this)
140 }
141
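    /// Matformer variant of [`Self::new`]: the weight is loaded at its original width
    /// (`orig_intermediate_size` input features), only the first `in_dim` columns are
    /// kept, and the row shard is applied to the result. Pre-quantized checkpoints are
    /// rejected.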
142 #[allow(clippy::new_ret_no_self)]
143 pub fn new_matformer(
144 in_dim: usize,
145 out_dim: usize,
146 orig_intermediate_size: usize,
147 config: &Option<QuantizedConfig>,
148 bias: bool,
149 comm: &Arc<crate::Comm>,
150 vb: ShardedVarBuilder,
151 ) -> Result<Arc<dyn QuantMethod>> {
152 let rank = comm.rank();
153 let world_size = comm.world_size();
154 let shard = shard(1, rank, world_size);
155
156 let base_vb = vb.clone();
157 let vb = if should_apply_immediate_isq(&vb) {
158 vb.set_device(Device::Cpu)
159 } else {
160 vb
161 };
162
163 if config.is_some() {
164 candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
165 }
166
167 let weight = if !vb.contains_tensor("weight") {
169 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
170 Arc::new(layer) as Arc<dyn QuantMethod>
171 } else {
172 let weight = vb
173 .get_with_hints(
174 (out_dim, orig_intermediate_size),
175 "weight",
176 Default::default(),
177 )?
178 .i((.., ..in_dim))?
179 .contiguous()?;
180
181 let weight = shard.apply_to(&weight)?;
182 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
183
184 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
185 Linear::new(weight, None),
186 ))?;
187 Arc::new(layer) as Arc<dyn QuantMethod>
188 };
189
190 let bias = if bias && vb.contains_tensor("bias") {
192 Some(vb.get((out_dim,), "bias")?)
193 } else {
194 None
195 };
196
197 let this_unquant = Arc::new(Self {
198 weight,
199 bias,
200 all_reduce: distributed::SumAllReduce::new(comm),
201 });
202 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
203 Ok(this)
204 }
205}
206
207impl QuantMethod for RowParallelLayer {
208 fn new(_method: QuantMethodConfig) -> Result<Self>
209 where
210 Self: Sized,
211 {
212 candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
213 }
214
215 fn forward(&self, a: &Tensor) -> Result<Tensor> {
216 let mut xs = self.weight.forward(a)?;
217 xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
218 if let Some(bias) = &self.bias {
219 xs = xs.broadcast_add(bias)?;
220 }
221 Ok(xs)
222 }
223
224 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
225 let weight = self.weight.add_delta_w(delta)?;
226 Ok(Arc::new(Self {
227 weight,
228 bias: self.bias.clone(),
229 all_reduce: self.all_reduce.clone(),
230 }))
231 }
232
233 fn dequantize_w(&self) -> Result<Tensor> {
234 self.weight.dequantize_w()
235 }
236
237 fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
238 self.weight.dtype_and_device()
239 }
240
241 fn begin_track_stats(&mut self) -> Result<()> {
242 Arc::get_mut(&mut self.weight)
243 .context("Failed to get &mut to weight")?
244 .begin_track_stats()
245 }
246
247 fn end_track_stats(&self) -> Result<Tensor> {
248 self.weight.end_track_stats()
249 }
250
251 fn quantized_act_type(&self) -> Option<candle_core::DType> {
252 self.weight.quantized_act_type()
253 }
254
255 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
256 self.weight.unquant_weight_bias()
257 }
258
259 fn apply_isq(
260 self: Arc<Self>,
261 dtype: Option<crate::IsqType>,
262 device: candle_core::Device,
263 n_quantized: &std::sync::atomic::AtomicUsize,
264 imatrix_weight: Option<Vec<f32>>,
265 guard: QuantizeOntoGuard,
266 ) -> Result<Arc<dyn QuantMethod>> {
267 let weight =
268 self.weight
269 .clone()
270 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
271 let bias = match &self.bias {
272 Some(b) => {
273 let (dtype, device) = weight.dtype_and_device();
274 Some(b.to_device(&device)?.to_dtype(dtype)?)
275 }
276 None => None,
277 };
278 Ok(Arc::new(Self {
279 weight,
280 bias,
281 all_reduce: self.all_reduce.clone(),
282 }))
283 }
284
285 fn is_distributed(&self) -> Option<DistributedKind> {
286 Some(DistributedKind::RowParallel)
287 }
288}
289
290impl QuantizedSerde for RowParallelLayer {
291 fn isq_serde_supported(&self) -> bool {
292 self.weight.isq_serde_supported()
293 }
294 fn name(&self) -> &'static str {
295 self.weight.name()
296 }
297 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
298 self.weight.serialize_with_bias(self.bias.clone())
299 }
300 fn deserialize(
301 data: std::borrow::Cow<[u8]>,
302 device: &candle_core::Device,
303 comm: &Arc<crate::Comm>,
304 guard: QuantizeOntoGuard,
305 ) -> Result<Arc<dyn QuantMethod>>
306 where
307 Self: Sized,
308 {
309 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
311 let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
312 QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
313 QuantizedSerdeType::Unquant => {
314 UnquantLinear::deserialize_ext_bias(data, device, guard)?
315 }
316 QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
317 QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
318 QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
319 };
320 Ok(Arc::new(Self {
321 weight,
322 bias,
323 all_reduce: SumAllReduce::new(comm),
324 }))
325 }
326}
327
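/// Linear layer whose weight (and bias, if present) is sharded along the output
/// dimension (dim 0). Each rank produces a disjoint slice of the output features, so
/// no collective communication is needed in `forward`.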
328#[derive(Debug)]
329pub struct ColumnParallelLayer {
332 weight: Arc<dyn QuantMethod>,
333 bias: Option<Tensor>,
334}
335
336impl ColumnParallelLayer {
337 #[allow(clippy::new_ret_no_self)]
338 pub fn new_with_shard(
339 in_dim: usize,
340 out_dim: usize,
341 config: &Option<QuantizedConfig>,
342 bias: bool,
343 comm: &Arc<crate::Comm>,
344 shard: Shard,
345 vb: ShardedVarBuilder,
346 ) -> Result<Arc<dyn QuantMethod>> {
347 let base_vb = vb.clone();
348 let vb = if should_apply_immediate_isq(&vb) {
349 vb.set_device(Device::Cpu)
350 } else {
351 vb
352 };
353
354 let weight = if let Some(quant_conf) = &config {
355 if matches!(
357 quant_conf,
358 QuantizedConfig::GptqAwq { .. }
359 | QuantizedConfig::Bitsandbytes { .. }
360 | QuantizedConfig::Afq { .. }
361 ) && comm.world_size() != 1
362 {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
367 }
368
369 match quant_conf {
370 QuantizedConfig::GptqAwq { .. } => {
371 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
372 }
373 QuantizedConfig::Fp8 { weight_block_size } => {
374 if weight_block_size.is_some() {
376 blockwise_fp8_linear_b(
377 in_dim,
378 out_dim,
379 quant_conf,
380 false,
381 shard,
382 vb.clone(),
383 )?
384 } else {
385 pertensor_fp8_linear_b(
386 in_dim,
387 out_dim,
388 quant_conf,
389 false,
390 shard,
391 vb.clone(),
392 )?
393 }
394 }
395 QuantizedConfig::Bitsandbytes { .. } => {
396 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
397 }
398 QuantizedConfig::Afq { .. } => {
399 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
400 }
401 QuantizedConfig::MXFP4 {} => {
402 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
403 }
404 }
405 } else {
406 if !vb.contains_tensor("weight") {
408 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
409 Arc::new(layer) as Arc<dyn QuantMethod>
410 } else {
411 let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
412 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
413
414 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
415 Linear::new(weight, None),
416 ))?;
417 Arc::new(layer) as Arc<dyn QuantMethod>
418 }
419 };
420
421 let bias = if bias && vb.contains_tensor("bias") {
423 Some(vb.get_with_hints((out_dim,), "bias", shard)?)
424 } else {
425 None
426 };
427
428 let this_unquant = Arc::new(Self { weight, bias });
429 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
430 Ok(this)
431 }
432
433 #[allow(clippy::new_ret_no_self)]
434 pub fn new(
435 in_dim: usize,
436 out_dim: usize,
437 config: &Option<QuantizedConfig>,
438 bias: bool,
439 comm: &Arc<crate::Comm>,
440 vb: ShardedVarBuilder,
441 ) -> Result<Arc<dyn QuantMethod>> {
442 let rank = comm.rank();
443 let world_size = comm.world_size();
444 let shard = shard(0, rank, world_size);
445
446 Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
447 }
448
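    /// Matformer variant of [`Self::new`]: the weight is loaded at its original height
    /// (`orig_intermediate_size` output features), only the first `out_dim` rows are
    /// kept, and the column shard is applied to the result. Pre-quantized checkpoints
    /// are rejected.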
449 #[allow(clippy::new_ret_no_self)]
450 pub fn new_matformer(
451 in_dim: usize,
452 out_dim: usize,
453 orig_intermediate_size: usize,
454 config: &Option<QuantizedConfig>,
455 bias: bool,
456 comm: &Arc<crate::Comm>,
457 vb: ShardedVarBuilder,
458 ) -> Result<Arc<dyn QuantMethod>> {
459 let rank = comm.rank();
460 let world_size = comm.world_size();
461 let shard = shard(0, rank, world_size);
462
463 let base_vb = vb.clone();
464 let vb = if should_apply_immediate_isq(&vb) {
465 vb.set_device(Device::Cpu)
466 } else {
467 vb
468 };
469
470 if config.is_some() {
471 candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
472 }
473
474 let weight = if !vb.contains_tensor("weight") {
476 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
477 Arc::new(layer) as Arc<dyn QuantMethod>
478 } else {
479 let weight = vb
480 .get_with_hints(
481 (orig_intermediate_size, in_dim),
482 "weight",
483 Default::default(),
484 )?
485 .i((..out_dim, ..))?
486 .contiguous()?;
487
488 let weight = shard.apply_to(&weight)?;
489 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
490
491 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
492 Linear::new(weight, None),
493 ))?;
494 Arc::new(layer) as Arc<dyn QuantMethod>
495 };
496
497 let bias = if bias && vb.contains_tensor("bias") {
499 Some(vb.get_with_hints((out_dim,), "bias", shard)?)
500 } else {
501 None
502 };
503
504 let this_unquant = Arc::new(Self { weight, bias });
505 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
506 Ok(this)
507 }
508
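    /// Loads `chunks` column-parallel layers out of a single merged weight (for example
    /// fused gate/up projections). Chunk `i` on rank `r` uses shard index
    /// `i * world_size + r` out of `chunks * world_size` total slices along dim 0, so
    /// every returned layer only sees its own chunk's slice for this rank.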
509 pub fn new_merged(
510 in_dim: usize,
511 out_dim: usize,
512 chunks: usize,
513 config: &Option<QuantizedConfig>,
514 bias: bool,
515 comm: &Arc<crate::Comm>,
516 vb: ShardedVarBuilder,
517 ) -> Result<Vec<Arc<dyn QuantMethod>>> {
518 let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
519 for chunk_idx in 0..chunks {
520 let layer = ColumnParallelLayer::new_with_shard(
521 in_dim,
522 out_dim,
523 config,
524 bias,
525 comm,
526 shard(
527 0,
528 chunk_idx * comm.world_size() + comm.rank(),
529 chunks * comm.world_size(),
530 ),
531 vb.clone(),
532 )?;
533 vec_layers.push(layer);
534 }
535 Ok(vec_layers)
536 }
537}
538
539impl QuantMethod for ColumnParallelLayer {
540 fn new(_method: QuantMethodConfig) -> Result<Self>
541 where
542 Self: Sized,
543 {
544 candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
545 }
546
547 fn forward(&self, a: &Tensor) -> Result<Tensor> {
548 let mut xs = self.weight.forward(a)?;
549 if let Some(bias) = &self.bias {
550 xs = xs.broadcast_add(bias)?;
551 }
552 Ok(xs)
553 }
554
555 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
556 let weight = self.weight.add_delta_w(delta)?;
557 Ok(Arc::new(Self {
558 weight,
559 bias: self.bias.clone(),
560 }))
561 }
562
563 fn dequantize_w(&self) -> Result<Tensor> {
564 self.weight.dequantize_w()
565 }
566
567 fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
568 self.weight.dtype_and_device()
569 }
570
571 fn begin_track_stats(&mut self) -> Result<()> {
572 Arc::get_mut(&mut self.weight)
573 .context("Failed to get &mut to weight")?
574 .begin_track_stats()
575 }
576
577 fn end_track_stats(&self) -> Result<Tensor> {
578 self.weight.end_track_stats()
579 }
580
581 fn quantized_act_type(&self) -> Option<candle_core::DType> {
582 self.weight.quantized_act_type()
583 }
584
585 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
586 self.weight.unquant_weight_bias()
587 }
588
589 fn apply_isq(
590 self: Arc<Self>,
591 dtype: Option<crate::IsqType>,
592 device: candle_core::Device,
593 n_quantized: &std::sync::atomic::AtomicUsize,
594 imatrix_weight: Option<Vec<f32>>,
595 guard: QuantizeOntoGuard,
596 ) -> Result<Arc<dyn QuantMethod>> {
597 let weight =
598 self.weight
599 .clone()
600 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
601 let bias = match &self.bias {
602 Some(b) => {
603 let (dtype, device) = weight.dtype_and_device();
604 Some(b.to_device(&device)?.to_dtype(dtype)?)
605 }
606 None => None,
607 };
608 Ok(Arc::new(Self { weight, bias }))
609 }
610
611 fn is_distributed(&self) -> Option<DistributedKind> {
612 Some(DistributedKind::ColumnParallel)
613 }
614}
615
616impl QuantizedSerde for ColumnParallelLayer {
617 fn isq_serde_supported(&self) -> bool {
618 self.weight.isq_serde_supported()
619 }
620 fn name(&self) -> &'static str {
621 self.weight.name()
622 }
623 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
624 self.weight.serialize_with_bias(self.bias.clone())
625 }
626 fn deserialize(
627 data: std::borrow::Cow<[u8]>,
628 device: &candle_core::Device,
629 _comm: &Arc<crate::Comm>,
630 guard: QuantizeOntoGuard,
631 ) -> Result<Arc<dyn QuantMethod>>
632 where
633 Self: Sized,
634 {
635 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
637 let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
638 QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
639 QuantizedSerdeType::Unquant => {
640 UnquantLinear::deserialize_ext_bias(data, device, guard)?
641 }
642 QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
643 QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
644 QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
645 };
646 Ok(Arc::new(Self { weight, bias }))
647 }
648}
649
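/// Linear layer that is fully replicated on every rank: the complete weight is loaded
/// without sharding and no collective communication is performed.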
650#[derive(Debug)]
651pub struct ReplicatedLayer(Arc<dyn QuantMethod>);
653
654impl ReplicatedLayer {
655 pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
656 let dev = lin.weight().device().clone();
657 let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
658 let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
659 Ok(this)
660 }
661
662 #[allow(clippy::new_ret_no_self)]
663 pub fn new(
664 in_dim: usize,
665 out_dim: usize,
666 config: &Option<QuantizedConfig>,
667 bias: bool,
668 vb: ShardedVarBuilder,
669 ) -> Result<Arc<dyn QuantMethod>> {
670 let base_vb = vb.clone();
671 let vb = if should_apply_immediate_isq(&vb) {
672 vb.set_device(Device::Cpu)
673 } else {
674 vb
675 };
676
677 let layer = if let Some(quant_conf) = &config {
678 match quant_conf {
679 QuantizedConfig::GptqAwq { .. } => {
680 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
681 }
682 QuantizedConfig::Fp8 { weight_block_size } => {
683 if weight_block_size.is_some() {
684 blockwise_fp8_linear_b(
685 in_dim,
686 out_dim,
687 quant_conf,
688 bias,
689 Default::default(),
690 vb.clone(),
691 )?
692 } else {
693 pertensor_fp8_linear_b(
694 in_dim,
695 out_dim,
696 quant_conf,
697 bias,
698 Default::default(),
699 vb.clone(),
700 )?
701 }
702 }
703 QuantizedConfig::Bitsandbytes { .. } => {
704 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
705 }
706 QuantizedConfig::Afq { .. } => {
707 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
708 }
709 QuantizedConfig::MXFP4 {} => {
710 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
711 }
712 }
713 } else {
714 if !vb.contains_tensor("weight") {
716 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
717 Arc::new(layer) as Arc<dyn QuantMethod>
718 } else {
719 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
720 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
721
722 let bias = if bias {
723 Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
724 } else {
725 None
726 };
727 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
728 Linear::new(weight, bias),
729 ))?;
730 Arc::new(layer) as Arc<dyn QuantMethod>
731 }
732 };
733
734 let this_unquant = Arc::new(Self(layer));
735 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
736 Ok(this)
737 }
738
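    /// Matformer variant for weights stacked per hidden layer: the `(out_dim, in_dim)`
    /// weight is viewed as `(orig_num_hidden_layers, out_dim / orig_num_hidden_layers,
    /// in_dim)`, only the rows selected by `kept_layers_indices` are kept, and the
    /// result is flattened back to 2D.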
739 #[allow(clippy::new_ret_no_self)]
740 pub fn new_layers_matformer_indices(
741 in_dim: usize,
742 out_dim: usize,
743 kept_layers_indices: Option<&Tensor>,
744 orig_num_hidden_layers: usize,
745 config: &Option<QuantizedConfig>,
746 bias: bool,
747 vb: ShardedVarBuilder,
748 ) -> Result<Arc<dyn QuantMethod>> {
749 let base_vb = vb.clone();
750 let vb = if should_apply_immediate_isq(&vb) {
751 vb.set_device(Device::Cpu)
752 } else {
753 vb
754 };
755
756 let layer = if let Some(quant_conf) = &config {
757 if kept_layers_indices.is_some() {
758 candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
759 }
760
761 match quant_conf {
762 QuantizedConfig::GptqAwq { .. } => {
763 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
764 }
765 QuantizedConfig::Fp8 { weight_block_size } => {
766 if weight_block_size.is_some() {
767 blockwise_fp8_linear_b(
768 in_dim,
769 out_dim,
770 quant_conf,
771 bias,
772 Default::default(),
773 vb.clone(),
774 )?
775 } else {
776 pertensor_fp8_linear_b(
777 in_dim,
778 out_dim,
779 quant_conf,
780 bias,
781 Default::default(),
782 vb.clone(),
783 )?
784 }
785 }
786 QuantizedConfig::Bitsandbytes { .. } => {
787 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
788 }
789 QuantizedConfig::Afq { .. } => {
790 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
791 }
792 QuantizedConfig::MXFP4 {} => {
793 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
794 }
795 }
796 } else {
797 if !vb.contains_tensor("weight") {
799 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
800 Arc::new(layer) as Arc<dyn QuantMethod>
801 } else {
802 let mut weight =
803 vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
804
805 if let Some(kept_layers_indices) = &kept_layers_indices {
806 let weight_reshaped = weight.reshape((
807 orig_num_hidden_layers,
808 weight.dim(0)? / orig_num_hidden_layers,
809 weight.dim(1)?,
810 ))?;
811
812 weight = weight_reshaped
813 .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
814 .reshape(((), weight_reshaped.dim(D::Minus1)?))?
815 .contiguous()?;
816 }
817
818 weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
819
820 let bias = if bias {
821 Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
822 } else {
823 None
824 };
825 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
826 Linear::new(weight, bias),
827 ))?;
828 Arc::new(layer) as Arc<dyn QuantMethod>
829 }
830 };
831
832 let this_unquant = Arc::new(Self(layer));
833 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
834 Ok(this)
835 }
836}
837
838impl QuantMethod for ReplicatedLayer {
839 fn new(_method: QuantMethodConfig) -> Result<Self>
840 where
841 Self: Sized,
842 {
843 candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
844 }
845
846 fn forward(&self, a: &Tensor) -> Result<Tensor> {
847 self.0.forward(a)
848 }
849
850 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
851 self.0.add_delta_w(delta)
852 }
853
854 fn dequantize_w(&self) -> Result<Tensor> {
855 self.0.dequantize_w()
856 }
857
858 fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
859 self.0.dtype_and_device()
860 }
861
862 fn begin_track_stats(&mut self) -> Result<()> {
863 Arc::get_mut(&mut self.0)
864 .context("Failed to get &mut to weight")?
865 .begin_track_stats()
866 }
867
868 fn end_track_stats(&self) -> Result<Tensor> {
869 self.0.end_track_stats()
870 }
871
872 fn quantized_act_type(&self) -> Option<candle_core::DType> {
873 self.0.quantized_act_type()
874 }
875
876 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
877 self.0.unquant_weight_bias()
878 }
879
880 fn apply_isq(
881 self: Arc<Self>,
882 dtype: Option<crate::IsqType>,
883 device: candle_core::Device,
884 n_quantized: &std::sync::atomic::AtomicUsize,
885 imatrix_weight: Option<Vec<f32>>,
886 guard: QuantizeOntoGuard,
887 ) -> Result<Arc<dyn QuantMethod>> {
888 self.0
889 .clone()
890 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
891 }
892
893 fn is_distributed(&self) -> Option<DistributedKind> {
894 Some(DistributedKind::Replicated)
895 }
896}
897
898impl QuantizedSerde for ReplicatedLayer {
899 fn isq_serde_supported(&self) -> bool {
900 self.0.isq_serde_supported()
901 }
902 fn name(&self) -> &'static str {
903 self.0.name()
904 }
905 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
906 self.0.serialize()
907 }
908 fn deserialize(
909 data: std::borrow::Cow<[u8]>,
910 device: &candle_core::Device,
911 comm: &Arc<crate::Comm>,
912 guard: QuantizeOntoGuard,
913 ) -> Result<Arc<dyn QuantMethod>>
914 where
915 Self: Sized,
916 {
917 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
919 let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
920 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
921 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
922 QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
923 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
924 QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
925 };
926 Ok(Arc::new(Self(deserialized)))
927 }
928}
929
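/// Per-expert gate/up/down projections for a MoE block. Depending on the checkpoint
/// layout and quantization config, each projection is loaded either as a single packed
/// layer covering all experts (AFQ, MXFP4) or as one layer per expert (blockwise FP8
/// and unquantized checkpoints).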
930#[derive(Debug)]
931pub struct PackedExperts {
932 pub gate_proj: Vec<Arc<dyn QuantMethod>>,
933 pub up_proj: Vec<Arc<dyn QuantMethod>>,
934 pub down_proj: Vec<Arc<dyn QuantMethod>>,
935}
936
937impl PackedExperts {
938 #[allow(clippy::too_many_arguments)]
940 pub fn new(
941 num_local_experts: usize,
942 hidden_size: usize,
943 intermediate_size: usize,
944 config: &Option<QuantizedConfig>,
945 bias: bool,
946 comm: &Arc<crate::Comm>,
947 vb: ShardedVarBuilder,
948 ) -> Result<Self> {
949 if bias {
950 candle_core::bail!("PackedExperts does not support bias.");
951 }
952
953 let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
954 if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support tensor parallelism (world size {}); use ISQ instead.",
                    comm.world_size()
                );
960 }
961
962 match quant_conf {
963 QuantizedConfig::Afq { .. } => {
                    if vb.contains_tensor("gate_up_proj")
                        || vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with an AFQ quantization config does not support the `gate_up_proj` format.");
                    }
969
970 let base_vb = vb.clone();
971
972 let vb_gate_proj = if should_apply_immediate_isq(&vb) {
973 vb.pp("gate_proj").set_device(Device::Cpu)
974 } else {
975 vb.pp("gate_proj")
976 };
977 let vb_up_proj = if should_apply_immediate_isq(&vb) {
978 vb.pp("up_proj").set_device(Device::Cpu)
979 } else {
980 vb.pp("up_proj")
981 };
982 let vb_down_proj = if should_apply_immediate_isq(&vb) {
983 vb.pp("down_proj").set_device(Device::Cpu)
984 } else {
985 vb.pp("down_proj")
986 };
987 let mut gate_proj = AfqLayer::afq_packed_linear_b(
988 num_local_experts,
989 hidden_size,
990 intermediate_size,
991 quant_conf,
992 bias,
993 vb_gate_proj,
994 )?;
995 let mut up_proj = AfqLayer::afq_packed_linear_b(
996 num_local_experts,
997 hidden_size,
998 intermediate_size,
999 quant_conf,
1000 bias,
1001 vb_up_proj,
1002 )?;
1003 let mut down_proj = AfqLayer::afq_packed_linear_b(
1004 num_local_experts,
1005 intermediate_size,
1006 hidden_size,
1007 quant_conf,
1008 bias,
1009 vb_down_proj,
1010 )?;
1011
1012 gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
1013 up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
1014 down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;
1015
1016 (vec![gate_proj], vec![up_proj], vec![down_proj])
1017 }
1018 QuantizedConfig::Fp8 { weight_block_size } => {
1019 let Some(weight_block_size) = weight_block_size else {
1022 candle_core::bail!("Blockwise FP8 for PackedExperts requires weight_block_size to be set.")
1023 };
1024 if weight_block_size.len() != 2 {
1025 candle_core::bail!(
1026 "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1027 );
1028 }
1029
1030 let is_stacked_format = vb.contains_tensor("gate_up_proj");
1033
1034 if is_stacked_format {
1035 let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");
1037
1038 if has_fp8_scales {
1039 let gate_up_fp8 = vb.get_with_hints_dtype(
1041 (num_local_experts, hidden_size, intermediate_size * 2),
1042 "gate_up_proj",
1043 Default::default(),
1044 candle_core::DType::F8E4M3,
1045 )?;
1046 let gate_up_scale = vb.get_with_hints_dtype(
1047 (
1048 num_local_experts,
1049 hidden_size.div_ceil(weight_block_size[0]),
1050 (intermediate_size * 2).div_ceil(weight_block_size[1]),
1051 ),
1052 "gate_up_proj.weight_scale_inv",
1053 Default::default(),
1054 candle_core::DType::F32,
1055 )?;
1056
1057 let down_fp8 = vb.get_with_hints_dtype(
1059 (num_local_experts, intermediate_size, hidden_size),
1060 "down_proj",
1061 Default::default(),
1062 candle_core::DType::F8E4M3,
1063 )?;
1064 let down_scale = vb.get_with_hints_dtype(
1065 (
1066 num_local_experts,
1067 intermediate_size.div_ceil(weight_block_size[0]),
1068 hidden_size.div_ceil(weight_block_size[1]),
1069 ),
1070 "down_proj.weight_scale_inv",
1071 Default::default(),
1072 candle_core::DType::F32,
1073 )?;
1074
1075 let mut gs = Vec::new();
1077 let mut us = Vec::new();
1078 let mut ds = Vec::new();
1079
1080 for i in 0..num_local_experts {
1081 let gate_up_expert =
1083 gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1084 let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
1085 let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1086 let down_scale_expert = down_scale.i(i)?.contiguous()?;
1087
1088 let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
1090 let up_expert = gate_up_expert.narrow(
1091 0,
1092 intermediate_size,
1093 intermediate_size,
1094 )?;
1095
1096 let gate_scale_expert = gate_up_scale_expert.narrow(
1098 1,
1099 0,
1100 intermediate_size.div_ceil(weight_block_size[1]),
1101 )?;
1102 let up_scale_expert = gate_up_scale_expert.narrow(
1103 1,
1104 intermediate_size.div_ceil(weight_block_size[1]),
1105 intermediate_size.div_ceil(weight_block_size[1]),
1106 )?;
1107
1108 use crate::blockwise_fp8::BlockwiseFP8Linear;
1110 use crate::QuantMethodConfig;
1111
1112 let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1113 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1114 weight: gate_expert,
1115 weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
1116 bias: None,
1117 dequant_dtype: vb.dtype(),
1118 weight_block_size: weight_block_size.clone(),
1119 })?,
1120 );
1121 let up_layer: Arc<dyn QuantMethod> = Arc::new(
1122 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1123 weight: up_expert,
1124 weight_scale_inv: up_scale_expert.transpose(0, 1)?,
1125 bias: None,
1126 dequant_dtype: vb.dtype(),
1127 weight_block_size: weight_block_size.clone(),
1128 })?,
1129 );
1130 let down_layer: Arc<dyn QuantMethod> = Arc::new(
1131 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1132 weight: down_expert,
1133 weight_scale_inv: down_scale_expert.transpose(0, 1)?,
1134 bias: None,
1135 dequant_dtype: vb.dtype(),
1136 weight_block_size: weight_block_size.clone(),
1137 })?,
1138 );
1139
1140 gs.push(gate_layer);
1141 us.push(up_layer);
1142 ds.push(down_layer);
1143 }
1144
1145 (gs, us, ds)
1146 } else {
1147 candle_core::bail!(
1148 "PackedExperts with FP8 requires weight_scale_inv tensors"
1149 );
1150 }
1151 } else {
1152 let mut gs = Vec::new();
1154 let mut us = Vec::new();
1155 let mut ds = Vec::new();
1156
1157 for i in 0..num_local_experts {
1158 let expert_vb = vb.pp(i);
1159
1160 let gate_fp8 = expert_vb.get_with_hints_dtype(
1162 (intermediate_size, hidden_size),
1163 "gate_proj.weight",
1164 Default::default(),
1165 candle_core::DType::F8E4M3,
1166 )?;
1167 let gate_scale = expert_vb.get_with_hints_dtype(
1168 (
1169 intermediate_size.div_ceil(weight_block_size[0]),
1170 hidden_size.div_ceil(weight_block_size[1]),
1171 ),
1172 "gate_proj.weight_scale_inv",
1173 Default::default(),
1174 candle_core::DType::F32,
1175 )?;
1176
1177 let up_fp8 = expert_vb.get_with_hints_dtype(
1178 (intermediate_size, hidden_size),
1179 "up_proj.weight",
1180 Default::default(),
1181 candle_core::DType::F8E4M3,
1182 )?;
1183 let up_scale = expert_vb.get_with_hints_dtype(
1184 (
1185 intermediate_size.div_ceil(weight_block_size[0]),
1186 hidden_size.div_ceil(weight_block_size[1]),
1187 ),
1188 "up_proj.weight_scale_inv",
1189 Default::default(),
1190 candle_core::DType::F32,
1191 )?;
1192
1193 let down_fp8 = expert_vb.get_with_hints_dtype(
1194 (hidden_size, intermediate_size),
1195 "down_proj.weight",
1196 Default::default(),
1197 candle_core::DType::F8E4M3,
1198 )?;
1199 let down_scale = expert_vb.get_with_hints_dtype(
1200 (
1201 hidden_size.div_ceil(weight_block_size[0]),
1202 intermediate_size.div_ceil(weight_block_size[1]),
1203 ),
1204 "down_proj.weight_scale_inv",
1205 Default::default(),
1206 candle_core::DType::F32,
1207 )?;
1208
1209 use crate::blockwise_fp8::BlockwiseFP8Linear;
1211 use crate::QuantMethodConfig;
1212
1213 let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1214 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1215 weight: gate_fp8,
1216 weight_scale_inv: gate_scale,
1217 bias: None,
1218 dequant_dtype: vb.dtype(),
1219 weight_block_size: weight_block_size.clone(),
1220 })?,
1221 );
1222 let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
1223 QuantMethodConfig::BlockwiseFP8 {
1224 weight: up_fp8,
1225 weight_scale_inv: up_scale,
1226 bias: None,
1227 dequant_dtype: vb.dtype(),
1228 weight_block_size: weight_block_size.clone(),
1229 },
1230 )?);
1231 let down_layer: Arc<dyn QuantMethod> = Arc::new(
1232 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1233 weight: down_fp8,
1234 weight_scale_inv: down_scale,
1235 bias: None,
1236 dequant_dtype: vb.dtype(),
1237 weight_block_size: weight_block_size.clone(),
1238 })?,
1239 );
1240
1241 gs.push(gate_layer);
1242 us.push(up_layer);
1243 ds.push(down_layer);
1244 }
1245
1246 (gs, us, ds)
1247 }
1248 }
1249 QuantizedConfig::MXFP4 {} => {
1250 let gate_proj = MXFP4Layer::packed_linear_b(
1254 num_local_experts,
1255 hidden_size,
1256 intermediate_size,
1257 quant_conf,
1258 bias,
1259 vb.pp("gate_proj"),
1260 )?;
1261 let up_proj = MXFP4Layer::packed_linear_b(
1262 num_local_experts,
1263 hidden_size,
1264 intermediate_size,
1265 quant_conf,
1266 bias,
1267 vb.pp("up_proj"),
1268 )?;
1269 let down_proj = MXFP4Layer::packed_linear_b(
1270 num_local_experts,
1271 intermediate_size,
1272 hidden_size,
1273 quant_conf,
1274 bias,
1275 vb.pp("down_proj"),
1276 )?;
1277
1278 (vec![gate_proj], vec![up_proj], vec![down_proj])
1279 }
                _ => candle_core::bail!(
                    "PackedExperts with a quantization config only supports AFQ, FP8, or MXFP4 quantization"
                ),
1283 }
1284 } else if !vb.contains_tensor("gate_up_proj") {
1285 let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
1287 let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
1288 let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
1289 for _ in 0..num_local_experts {
1290 gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1291 us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1292 ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1293 }
1294 (gs, us, ds)
1295 } else {
1296 let gate_up_block_size = intermediate_size / comm.world_size();
1304 let gate_up_start = gate_up_block_size * comm.rank();
1305
1306 let shard_gate = Shard::Offset {
1308 dim: 2,
1309 offset: gate_up_start,
1310 len: gate_up_block_size,
1311 };
1312 let shard_up = Shard::Offset {
1313 dim: 2,
1314 offset: intermediate_size + gate_up_start,
1315 len: gate_up_block_size,
1316 };
1317 let shard_down = Shard::Simple {
1318 dim: 1,
1319 rank: comm.rank(),
1320 world_size: comm.world_size(),
1321 };
1322
1323 let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
1324 vb.pp("gate_up_proj").set_device(Device::Cpu)
1325 } else {
1326 vb.pp("gate_up_proj")
1327 };
1328 let vb_down_proj = if should_apply_immediate_isq(&vb) {
1329 vb.pp("down_proj").set_device(Device::Cpu)
1330 } else {
1331 vb.pp("down_proj")
1332 };
1333
1334 let gate_proj = vb
1335 .get_with_hints(
1336 (num_local_experts, hidden_size, intermediate_size * 2),
1337 "gate_up_proj",
1338 shard_gate,
1339 )?
1340 .t()?
1341 .contiguous()?;
1342 let up_proj = vb
1343 .get_with_hints(
1344 (num_local_experts, hidden_size, intermediate_size * 2),
1345 "gate_up_proj",
1346 shard_up,
1347 )?
1348 .t()?
1349 .contiguous()?;
1350 let down_proj = vb
1351 .get_with_hints(
1352 (num_local_experts, intermediate_size, hidden_size),
1353 "down_proj",
1354 shard_down,
1355 )?
1356 .t()?
1357 .contiguous()?;
1358
1359 let gc = gate_proj.chunk(num_local_experts, 0)?;
1360 let uc = up_proj.chunk(num_local_experts, 0)?;
1361 let dc = down_proj.chunk(num_local_experts, 0)?;
1362 drop((gate_proj, up_proj, down_proj));
1363
1364 let mut gs = Vec::new();
1365 let mut us = Vec::new();
1366 let mut ds = Vec::new();
1367 for ((mut gate_proj, mut up_proj), mut down_proj) in
1368 gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
1369 {
1370 gate_proj = gate_proj.squeeze(0)?;
1371 up_proj = up_proj.squeeze(0)?;
1372 down_proj = down_proj.squeeze(0)?;
1373 let gate_proj = merge_lora_weights(
1374 &vb,
1375 gate_proj,
1376 hidden_size,
1377 intermediate_size * 2,
1378 shard_gate,
1379 )?;
1380 let up_proj =
1381 merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1382 let down_proj =
1383 merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1384
1385 let mut gate_proj: Arc<dyn QuantMethod> =
1386 Arc::new(<UnquantLinear as QuantMethod>::new(
1387 QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1388 )?);
1389 gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1390 let mut up_proj: Arc<dyn QuantMethod> =
1391 Arc::new(<UnquantLinear as QuantMethod>::new(
1392 QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1393 )?);
1394 up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1395 let mut down_proj: Arc<dyn QuantMethod> =
1396 Arc::new(<UnquantLinear as QuantMethod>::new(
1397 QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1398 )?);
1399 down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1400 gs.push(gate_proj);
1401 us.push(up_proj);
1402 ds.push(down_proj);
1403 }
1404 (gs, us, ds)
1405 };
1406
1407 Ok(Self {
1408 gate_proj,
1409 up_proj,
1410 down_proj,
1411 })
1412 }
1413}
1414
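/// MoE experts fused into a single stacked weight per projection (gate, up, down).
/// Supports AFQ-packed weights, stacked blockwise-FP8 checkpoints, MXFP4, stacked
/// unquantized `gate_up_proj` checkpoints, and per-expert checkpoints (FP8 or
/// unquantized), stacking the latter along the expert dimension at load time.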
1415pub struct FusedExperts {
1416 pub fused_gate_proj: Arc<dyn QuantMethod>,
1417 pub fused_up_proj: Arc<dyn QuantMethod>,
1418 pub fused_down_proj: Arc<dyn QuantMethod>,
1419}
1420
1421impl FusedExperts {
1422 pub fn new(
1423 hidden_size: usize,
1424 moe_intermediate_size: usize,
1425 num_experts: usize,
1426 quantization_config: &Option<QuantizedConfig>,
1427 vb: ShardedVarBuilder,
1428 ) -> Result<Self> {
1429 let experts_vb = vb.pp("experts");
1435 let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");
1436
1437 let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
1438 &quantization_config,
1439 Some(QuantizedConfig::Afq { .. })
1440 ) {
1441 let quantization_config = quantization_config.as_ref().unwrap();
1442
1443 let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1444 num_experts,
1445 hidden_size,
1446 moe_intermediate_size,
1447 quantization_config,
1448 false,
1449 vb.pp("switch_mlp.gate_proj"),
1450 )?;
1451 let fused_up_proj = AfqLayer::afq_packed_linear_b(
1452 num_experts,
1453 hidden_size,
1454 moe_intermediate_size,
1455 quantization_config,
1456 false,
1457 vb.pp("switch_mlp.up_proj"),
1458 )?;
1459 let fused_down_proj = AfqLayer::afq_packed_linear_b(
1460 num_experts,
1461 moe_intermediate_size,
1462 hidden_size,
1463 quantization_config,
1464 false,
1465 vb.pp("switch_mlp.down_proj"),
1466 )?;
1467
1468 (fused_gate_proj, fused_up_proj, fused_down_proj)
1469 } else if is_stacked_format
1470 && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
1471 {
1472 let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");
1475
1476 if has_fp8_scales {
1477 let weight_block_size = match quantization_config {
1478 Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1479 _ => unreachable!(),
1480 };
1481
1482 let Some(weight_block_size) = weight_block_size else {
1483 candle_core::bail!(
1484 "Blockwise FP8 for stacked experts requires weight_block_size to be set."
1485 )
1486 };
1487 if weight_block_size.len() != 2 {
1488 candle_core::bail!(
1489 "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1490 );
1491 }
1492
1493 let gate_up_fp8 = experts_vb.get_with_hints_dtype(
1496 (num_experts, hidden_size, moe_intermediate_size * 2),
1497 "gate_up_proj",
1498 Default::default(),
1499 candle_core::DType::F8E4M3,
1500 )?;
1501 let gate_up_scale = experts_vb.get_with_hints_dtype(
1502 (
1503 num_experts,
1504 hidden_size.div_ceil(weight_block_size[0]),
1505 (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
1506 ),
1507 "gate_up_proj.weight_scale_inv",
1508 Default::default(),
1509 candle_core::DType::F32,
1510 )?;
1511
1512 let down_fp8 = experts_vb.get_with_hints_dtype(
1515 (num_experts, moe_intermediate_size, hidden_size),
1516 "down_proj",
1517 Default::default(),
1518 candle_core::DType::F8E4M3,
1519 )?;
1520 let down_scale = experts_vb.get_with_hints_dtype(
1521 (
1522 num_experts,
1523 moe_intermediate_size.div_ceil(weight_block_size[0]),
1524 hidden_size.div_ceil(weight_block_size[1]),
1525 ),
1526 "down_proj.weight_scale_inv",
1527 Default::default(),
1528 candle_core::DType::F32,
1529 )?;
1530
1531 let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
1533 let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1534
1535 let gate_scale = gate_up_scale.narrow(
1537 2,
1538 0,
1539 moe_intermediate_size.div_ceil(weight_block_size[1]),
1540 )?;
1541 let up_scale = gate_up_scale.narrow(
1542 2,
1543 moe_intermediate_size.div_ceil(weight_block_size[1]),
1544 moe_intermediate_size.div_ceil(weight_block_size[1]),
1545 )?;
1546
1547 let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
1550 let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
1551 let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;
1553
1554 let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
1556 let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
1557 let down_scale = down_scale.transpose(1, 2)?.contiguous()?;
1558
1559 let fused_gate_proj =
1561 blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1562 let fused_up_proj =
1563 blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1564 let fused_down_proj =
1565 blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1566
1567 (fused_gate_proj, fused_up_proj, fused_down_proj)
1568 } else {
1569 tracing::warn!(
1571 "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
1572 Loading as unquantized."
1573 );
1574 let gate_up_proj = experts_vb.get(
1575 (num_experts, hidden_size, moe_intermediate_size * 2),
1576 "gate_up_proj",
1577 )?;
1578 let down_proj_packed = experts_vb.get(
1579 (num_experts, moe_intermediate_size, hidden_size),
1580 "down_proj",
1581 )?;
1582
1583 let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1585 let up_proj =
1586 gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1587
1588 let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1590 let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1591 let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1592
1593 let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1594 QuantMethodConfig::Unquantized(Linear::new(gate_proj.clone(), None)),
1595 )?);
1596 let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1597 QuantMethodConfig::Unquantized(Linear::new(up_proj.clone(), None)),
1598 )?);
1599 let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1600 QuantMethodConfig::Unquantized(Linear::new(down_proj.clone(), None)),
1601 )?);
1602 let device = gate_proj.device();
1604 fused_gate_proj = apply_immediate_isq_always(fused_gate_proj, device)?;
1605 fused_up_proj = apply_immediate_isq_always(fused_up_proj, device)?;
1606 fused_down_proj = apply_immediate_isq_always(fused_down_proj, device)?;
1607
1608 (fused_gate_proj, fused_up_proj, fused_down_proj)
1609 }
1610 } else if is_stacked_format
1611 && matches!(&quantization_config, Some(QuantizedConfig::MXFP4 {}))
1612 {
1613 let quantization_config = quantization_config.as_ref().unwrap();
1617
1618 let fused_gate_proj = MXFP4Layer::packed_linear_b(
1623 num_experts,
1624 hidden_size,
1625 moe_intermediate_size,
1626 quantization_config,
1627 false,
1628 experts_vb.pp("gate_proj"),
1629 )?;
1630 let fused_up_proj = MXFP4Layer::packed_linear_b(
1631 num_experts,
1632 hidden_size,
1633 moe_intermediate_size,
1634 quantization_config,
1635 false,
1636 experts_vb.pp("up_proj"),
1637 )?;
1638 let fused_down_proj = MXFP4Layer::packed_linear_b(
1639 num_experts,
1640 moe_intermediate_size,
1641 hidden_size,
1642 quantization_config,
1643 false,
1644 experts_vb.pp("down_proj"),
1645 )?;
1646
1647 (fused_gate_proj, fused_up_proj, fused_down_proj)
1648 } else if is_stacked_format {
1649 let gate_up_proj = experts_vb.get(
1657 (num_experts, hidden_size, moe_intermediate_size * 2),
1658 "gate_up_proj",
1659 )?;
1660 let down_proj_packed = experts_vb.get(
1661 (num_experts, moe_intermediate_size, hidden_size),
1662 "down_proj",
1663 )?;
1664
1665 let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1669 let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1670
1671 let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1674 let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1675 let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1677
1678 let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1679 QuantMethodConfig::Unquantized(Linear::new(gate_proj.clone(), None)),
1680 )?);
1681 let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1682 QuantMethodConfig::Unquantized(Linear::new(up_proj.clone(), None)),
1683 )?);
1684 let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1685 QuantMethodConfig::Unquantized(Linear::new(down_proj.clone(), None)),
1686 )?);
1687 let device = gate_proj.device();
1689 fused_gate_proj = apply_immediate_isq_always(fused_gate_proj, device)?;
1690 fused_up_proj = apply_immediate_isq_always(fused_up_proj, device)?;
1691 fused_down_proj = apply_immediate_isq_always(fused_down_proj, device)?;
1692
1693 (fused_gate_proj, fused_up_proj, fused_down_proj)
1694 } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1695 let weight_block_size = match quantization_config {
1698 Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1699 _ => unreachable!(),
1700 };
1701
1702 let Some(weight_block_size) = weight_block_size else {
1703 candle_core::bail!(
1704 "Blockwise FP8 for per-expert format requires weight_block_size to be set."
1705 )
1706 };
1707 if weight_block_size.len() != 2 {
1708 candle_core::bail!(
1709 "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1710 );
1711 }
1712
1713 let mut gate_fp8_vec = Vec::new();
1714 let mut gate_scale_vec = Vec::new();
1715 let mut up_fp8_vec = Vec::new();
1716 let mut up_scale_vec = Vec::new();
1717 let mut down_fp8_vec = Vec::new();
1718 let mut down_scale_vec = Vec::new();
1719
1720 for i in 0..num_experts {
1721 let expert_vb = experts_vb.pp(i);
1722
1723 let gate_fp8 = expert_vb.get_with_hints_dtype(
1725 (moe_intermediate_size, hidden_size),
1726 "gate_proj.weight",
1727 Default::default(),
1728 candle_core::DType::F8E4M3,
1729 )?;
1730 let gate_scale = expert_vb.get_with_hints_dtype(
1731 (
1732 moe_intermediate_size.div_ceil(weight_block_size[0]),
1733 hidden_size.div_ceil(weight_block_size[1]),
1734 ),
1735 "gate_proj.weight_scale_inv",
1736 Default::default(),
1737 candle_core::DType::F32,
1738 )?;
1739
1740 let up_fp8 = expert_vb.get_with_hints_dtype(
1741 (moe_intermediate_size, hidden_size),
1742 "up_proj.weight",
1743 Default::default(),
1744 candle_core::DType::F8E4M3,
1745 )?;
1746 let up_scale = expert_vb.get_with_hints_dtype(
1747 (
1748 moe_intermediate_size.div_ceil(weight_block_size[0]),
1749 hidden_size.div_ceil(weight_block_size[1]),
1750 ),
1751 "up_proj.weight_scale_inv",
1752 Default::default(),
1753 candle_core::DType::F32,
1754 )?;
1755
1756 let down_fp8 = expert_vb.get_with_hints_dtype(
1757 (hidden_size, moe_intermediate_size),
1758 "down_proj.weight",
1759 Default::default(),
1760 candle_core::DType::F8E4M3,
1761 )?;
1762 let down_scale = expert_vb.get_with_hints_dtype(
1763 (
1764 hidden_size.div_ceil(weight_block_size[0]),
1765 moe_intermediate_size.div_ceil(weight_block_size[1]),
1766 ),
1767 "down_proj.weight_scale_inv",
1768 Default::default(),
1769 candle_core::DType::F32,
1770 )?;
1771
1772 gate_fp8_vec.push(gate_fp8);
1773 gate_scale_vec.push(gate_scale);
1774 up_fp8_vec.push(up_fp8);
1775 up_scale_vec.push(up_scale);
1776 down_fp8_vec.push(down_fp8);
1777 down_scale_vec.push(down_scale);
1778 }
1779
1780 let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
1782 let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
1783 let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
1784 let up_scale = Tensor::stack(&up_scale_vec, 0)?;
1785 let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
1786 let down_scale = Tensor::stack(&down_scale_vec, 0)?;
1787
1788 let fused_gate_proj =
1790 blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1791 let fused_up_proj =
1792 blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1793 let fused_down_proj =
1794 blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1795
1796 (fused_gate_proj, fused_up_proj, fused_down_proj)
1797 } else {
1798 let mut gate_proj_vec = Vec::new();
1800 let mut up_proj_vec = Vec::new();
1801 let mut down_proj_vec = Vec::new();
1802 for i in 0..num_experts {
1803 let expert_vb = experts_vb.pp(i);
1804 let gate_proj =
1805 expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1806 let up_proj =
1807 expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1808 let down_proj =
1809 expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1810
1811 gate_proj_vec.push(gate_proj);
1812 up_proj_vec.push(up_proj);
1813 down_proj_vec.push(down_proj);
1814 }
1815
1816 let mut gate_proj: Arc<dyn QuantMethod> =
1817 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1818 Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1819 ))?);
1820 let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1821 QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
1822 )?);
1823 let mut down_proj: Arc<dyn QuantMethod> =
1824 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1825 Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1826 ))?);
1827 let expert0_vb = experts_vb.pp("0");
1829 gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
1830 up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
1831 down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;
1832
1833 (gate_proj, up_proj, down_proj)
1834 };
1835
1836 Ok(Self {
1837 fused_gate_proj,
1838 fused_up_proj,
1839 fused_down_proj,
1840 })
1841 }
1842}
1843
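/// Computes how a KV projection weight is sharded across ranks. When there are more
/// ranks than KV heads, each head is replicated over `world_size / total_num_kv_heads`
/// ranks and every rank takes a single head-sized slice; otherwise the heads are split
/// evenly along dim 0. For example, with 4 KV heads on 8 ranks, ranks 0 and 1 both
/// take head 0, ranks 2 and 3 take head 1, and so on.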
1844pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
1846 if comm.world_size() == 1 {
1847 return Shard::default();
1848 }
1849
1850 let kv_replicate = if comm.world_size() > total_num_kv_heads {
1854 comm.world_size() / total_num_kv_heads
1855 } else {
1856 return Shard::Simple {
1857 dim: 0,
1858 rank: comm.rank(),
1859 world_size: comm.world_size(),
1860 };
1861 };
1862
1863 let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
1864 let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
1865 Shard::Offset {
1866 dim: 0,
1867 offset: kv_shard_id * head_dim,
1868 len: head_dim,
1869 }
1870}
1871
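/// Number of query heads per KV head after tensor parallelism. When KV heads are
/// replicated (world size larger than the KV head count), the GQA group count is
/// divided by the replication factor.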
1872pub fn compute_n_kv_groups(
1874 total_num_kv_heads: usize,
1875 num_attention_heads: usize,
1876 comm: &Comm,
1877) -> usize {
1878 let kv_replicate = if comm.world_size() > total_num_kv_heads {
1879 comm.world_size() / total_num_kv_heads
1880 } else {
1881 1
1882 };
1883 if kv_replicate != 0 {
1884 (num_attention_heads / total_num_kv_heads) / kv_replicate
1885 } else {
1886 num_attention_heads / total_num_kv_heads
1887 }
1888}