use std::sync::Arc;

use candle_core::{Context, Device, Result, Tensor};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b,
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType,
    Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

/// Build a [`Shard::Simple`] that splits `dim` evenly across `world_size` ranks.
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

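/// Tensor-parallel linear layer whose weight is sharded along the *input* dimension
/// (dim 1 of the `(out_dim, in_dim)` weight), so each rank computes a partial matmul
/// that is summed across ranks with an all-reduce in `forward`.
///
/// The bias is kept at full size and added only after the all-reduce. In the usual
/// Megatron-style pairing (an assumption about how callers combine these types, not
/// something this file enforces), a [`ColumnParallelLayer`] feeds a [`RowParallelLayer`]
/// so that a single all-reduce per pair suffices.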
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ, BNB and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // The bias is not sharded: it is added once, after the partial results have been
        // all-reduced in `forward`.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

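/// Tensor-parallel linear layer whose weight is sharded along the *output* dimension
/// (dim 0 of the `(out_dim, in_dim)` weight). Each rank produces its own slice of the
/// output features, so no all-reduce is needed in `forward`; the bias is sharded the
/// same way as the weight.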
#[derive(Debug)]
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // The bias is sharded along the output dimension, matching the weight shard.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

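/// Linear layer that is fully replicated on every rank: the weight is loaded without a
/// shard, so no communication is needed in `forward`.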
#[derive(Debug)]
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

            let bias = if bias {
                Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
            } else {
                None
            };
            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, bias),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

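/// Per-expert gate/up/down projections for a packed-MoE block, loaded either from AFQ
/// packed tensors or from an unquantized packed `gate_up_proj` tensor of shape
/// `(num_local_experts, hidden_size, 2 * intermediate_size)`, where the first
/// `intermediate_size` columns of the last dimension hold the gate projection and the
/// remaining columns hold the up projection.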
#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support distributed inference (world size {}). Use ISQ instead.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with an AFQ quantization config does not support the `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with a quantization config only supports AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            // No packed tensor is present: fill each expert slot with a dummy layer.
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            // Split the packed `gate_up_proj` tensor: the gate slice starts at column 0
            // and the up slice starts at `intermediate_size`, each column-sharded across
            // ranks; `down_proj` is sharded along its input dimension.
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            // Split into one weight per expert, then wrap each as an unquantized linear.
            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

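/// Gate/up/down projections for all experts fused into single tensors: non-AFQ weights
/// are loaded (dequantizing FP8 first) and stacked into `(num_experts, out_dim, in_dim)`
/// tensors, while AFQ weights are loaded as packed AFQ linears. The constructor
/// currently requires a Metal device.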
pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if !vb.device().is_metal() {
            candle_core::bail!("FastMoeMlp requires Metal.");
        }

        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
                let quantization_config = quantization_config.as_ref().unwrap();

                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.gate_proj"),
                )?;
                let fused_up_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.up_proj"),
                )?;
                let fused_down_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    moe_intermediate_size,
                    hidden_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.down_proj"),
                )?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
                // Dequantize each expert's FP8 weights, then stack them into fused tensors.
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);

                    let gate_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("gate_proj.weight"),
                    )?;
                    let up_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("up_proj.weight"),
                    )?;
                    let down_proj = crate::linear_no_bias(
                        moe_intermediate_size,
                        hidden_size,
                        quantization_config,
                        vb.pp("down_proj.weight"),
                    )?;

                    gate_proj_vec.push(gate_proj.dequantize_w()?);
                    up_proj_vec.push(up_proj.dequantize_w()?);
                    down_proj_vec.push(down_proj.dequantize_w()?);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            } else {
                // Unquantized weights: load each expert's projections and stack them.
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);
                    let gate_proj =
                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                    let down_proj =
                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                    gate_proj_vec.push(gate_proj);
                    up_proj_vec.push(up_proj);
                    down_proj_vec.push(down_proj);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

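/// Compute the weight shard for a K/V projection under tensor parallelism.
///
/// When `world_size <= total_num_kv_heads` the heads are simply split across ranks
/// (`Shard::Simple` on dim 0). When there are more ranks than KV heads, each head is
/// replicated on `world_size / total_num_kv_heads` consecutive ranks and every rank
/// loads exactly one head. A worked example of that second case: with
/// `total_num_kv_heads = 8`, `head_dim = 128`, `world_size = 16` and `rank = 5`,
/// `kv_replicate = 2`, `num_kv_heads = 1`, `kv_shard_id = (5 / 2) * 1 = 2`, so ranks 4
/// and 5 both load rows `[256, 384)` of the projection (KV head 2).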
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    // If there are more ranks than KV heads, each head must be replicated across
    // `world_size / total_num_kv_heads` ranks; otherwise the heads are split evenly.
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

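/// Compute the number of query heads per K/V head (the GQA group size) on each rank.
///
/// The global group size is `num_attention_heads / total_num_kv_heads`; when KV heads
/// are replicated because `world_size > total_num_kv_heads`, each rank sees
/// correspondingly fewer query heads per KV head. For example, with 32 attention heads,
/// 8 KV heads and a world size of 16, `kv_replicate = 2` and the per-rank group size is
/// `(32 / 8) / 2 = 2`.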
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}