use std::sync::Arc;

use candle_core::{Context, Device, Result, Tensor};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b,
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType,
    Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

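/// Convenience constructor for a [`Shard::Simple`] describing this rank's slice along `dim`.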
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

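/// A linear layer whose weight is sharded along the input dimension (dim 1).
/// Each rank computes a partial matmul over its slice of the input features;
/// `forward` sums the partial results with an all-reduce and only then adds the
/// replicated bias, so the bias is applied exactly once.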
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

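/// A linear layer whose weight (and bias) are sharded along the output dimension
/// (dim 0). Each rank produces its own slice of the output features, so no
/// cross-rank reduction is needed in `forward`.
///
/// A rough sketch of how the column/row pair is typically composed into a
/// tensor-parallel MLP (assuming `cfg`, `comm`, and `vb` are already in scope;
/// the `"up"`/`"down"` prefixes are illustrative, not a compiled doctest):
///
/// ```ignore
/// // Shard the up-projection over output features, one slice per rank.
/// let up = ColumnParallelLayer::new(hidden, intermediate, &cfg, false, &comm, vb.pp("up"))?;
/// // Shard the down-projection over input features; its forward() all-reduces
/// // the partial sums so every rank ends up with the full hidden activation.
/// let down = RowParallelLayer::new(intermediate, hidden, &cfg, false, &comm, vb.pp("down"))?;
/// let ys = down.forward(&up.forward(&xs)?)?;
/// ```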
#[derive(Debug)]
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

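    /// Load `chunks` projections that are stored fused along dim 0 of one tensor
    /// (e.g. gate and up). Chunk `i` on this rank corresponds to shard index
    /// `i * world_size + rank` out of `chunks * world_size` total shards.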
    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

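/// A linear layer that is fully replicated on every rank; no sharding or
/// cross-rank communication is involved.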
#[derive(Debug)]
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

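/// Per-expert gate/up/down projections for a MoE layer, loaded either from a
/// packed AFQ checkpoint or from a fused `gate_up_proj`/`down_proj` layout and
/// split into one layer per local expert.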
#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support distributed inference (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with an AFQ quantization config does not support the `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with a quantization config only allows AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
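            // `gate_up_proj` stores gate and up weights fused along the last dim:
            // shape (num_local_experts, hidden_size, 2 * intermediate_size), gate
            // half first, up half second. Each rank reads its slice of each half
            // along dim 2, and its slice of `down_proj` along dim 1.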
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

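/// MoE expert projections stacked into single per-projection tensors of shape
/// `[num_experts, ...]`; requires a Metal device.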
pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if !vb.device().is_metal() {
            candle_core::bail!("FastMoeMlp requires Metal.");
        }

        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
                let quantization_config = quantization_config.as_ref().unwrap();

                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.gate_proj"),
                )?;
                let fused_up_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.up_proj"),
                )?;
                let fused_down_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    moe_intermediate_size,
                    hidden_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.down_proj"),
                )?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);

                    let gate_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("gate_proj.weight"),
                    )?;
                    let up_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("up_proj.weight"),
                    )?;
                    let down_proj = crate::linear_no_bias(
                        moe_intermediate_size,
                        hidden_size,
                        quantization_config,
                        vb.pp("down_proj.weight"),
                    )?;

                    gate_proj_vec.push(gate_proj.dequantize_w()?);
                    up_proj_vec.push(up_proj.dequantize_w()?);
                    down_proj_vec.push(down_proj.dequantize_w()?);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            } else {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);
                    let gate_proj =
                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                    let down_proj =
                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                    gate_proj_vec.push(gate_proj);
                    up_proj_vec.push(up_proj);
                    down_proj_vec.push(down_proj);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

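/// Compute the weight shard for K/V projections under tensor parallelism.
/// If there are at least as many KV heads as ranks, each rank takes a simple
/// slice along dim 0; otherwise the KV heads are replicated so that groups of
/// `world_size / total_num_kv_heads` ranks share the same head slice.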
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

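/// Number of query-head groups per KV head on this rank (the GQA group size),
/// accounting for KV-head replication when `world_size > total_num_kv_heads`.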
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}