mistralrs_quant/distributed/layers.rs

use std::sync::Arc;

use candle_core::{Context, Device, Result, Tensor};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b,
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType,
    Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}
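
// Illustrative note (not from the original source): `Shard::Simple` selects an even
// 1/world_size slice of a tensor along `dim`, with `rank` choosing which slice this
// process loads. For example, with an assumed (out_dim = 4096, in_dim = 4096) weight,
// `shard(1, /*rank=*/1, /*world_size=*/4)` corresponds to input columns 1024..2048 on
// that rank.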

/// This layer has a weight that is parallelized along the input dimension,
/// returning the "full" output dimension.
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}
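
// A rough sketch of the row-parallel math, for clarity (assumed shapes, not part of the
// implementation): each rank holds an (out_dim, in_dim / world_size) slice W_r of the
// weight and sees the matching slice x_r of the activations, so it computes a partial
// product y_r = x_r * W_r^T. `forward` recovers the full result by summing the partials
// across ranks with `SumAllReduce` (y = sum_r y_r) and only then adds the bias, so the
// bias is applied exactly once.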

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

#[derive(Debug)]
/// This layer has a weight that is parallelized along the output dimension,
/// taking the "full" input dimension.
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}
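
// Complementary sketch for the column-parallel case (assumed shapes, illustrative only):
// each rank holds an (out_dim / world_size, in_dim) slice W_r and receives the full
// input x, producing its own slice of the output, y_r = x * W_r^T. No reduction is
// needed; the outputs simply stay sharded along the feature dimension, which is why
// `forward` below has no all-reduce.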

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}
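
// How these layers are typically combined (a sketch of the usual tensor-parallel MLP
// wiring, not something this file prescribes): the gate/up projections use
// `ColumnParallelLayer`, so the activation is applied elementwise to the sharded
// intermediate on each rank, and the down projection uses `RowParallelLayer`, whose
// single all-reduce restores the full hidden state. That keeps communication to one
// reduction per MLP block.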

#[derive(Debug)]
/// This layer has no parallelization
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    /// Note: we only support AFQ and unquantized here because they are the only ones that support indexed (gather) matmul.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Pre-quantized packed experts do not support tensor parallelism
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            // Parallelized as follows:
            // - Each GPU holds all experts.
            // - The gate/up proj is parallelized on dim 2 (column).
            // - The down proj is parallelized on dim 1 (row).
            // - All-reduce at the end (see the worked example just below).
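            //
            // Worked example of the sharding arithmetic below (assumed numbers, for
            // illustration only): with intermediate_size = 8192 and world_size = 4,
            // gate_up_block_size = 2048. Rank 1 then loads
            //   gate slice: gate_up_proj[.., .., 2048..4096]
            //   up slice:   gate_up_proj[.., .., 8192 + 2048..8192 + 4096]
            // and its down_proj shard covers rows 2048..4096 along dim 1 (Shard::Simple).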

            // Compute the per-rank shards of the packed gate_up_proj and down_proj tensors.
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            // Gate is right before Up in the gate_up
            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if !vb.device().is_metal() {
            candle_core::bail!("FastMoeMlp requires Metal.");
        }

        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
                let quantization_config = quantization_config.as_ref().unwrap();

                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.gate_proj"),
                )?;
                let fused_up_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.up_proj"),
                )?;
                let fused_down_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    moe_intermediate_size,
                    hidden_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.down_proj"),
                )?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);

                    let gate_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("gate_proj.weight"),
                    )?;
                    let up_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("up_proj.weight"),
                    )?;
                    let down_proj = crate::linear_no_bias(
                        moe_intermediate_size,
                        hidden_size,
                        quantization_config,
                        vb.pp("down_proj.weight"),
                    )?;

                    gate_proj_vec.push(gate_proj.dequantize_w()?);
                    up_proj_vec.push(up_proj.dequantize_w()?);
                    down_proj_vec.push(down_proj.dequantize_w()?);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            } else {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);
                    let gate_proj =
                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                    let down_proj =
                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                    gate_proj_vec.push(gate_proj);
                    up_proj_vec.push(up_proj);
                    down_proj_vec.push(down_proj);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    // Tensor parallelism case

    // We may need to replicate the kv heads
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}
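
// Worked example for `compute_kv_shard` (assumed numbers, for illustration only): with
// total_num_kv_heads = 2, head_dim = 128, and world_size = 8, the KV heads must be
// replicated, so kv_replicate = 4 and num_kv_heads = 1. Ranks 0..=3 get
// Shard::Offset { dim: 0, offset: 0, len: 128 } (KV head 0) and ranks 4..=7 get
// Shard::Offset { dim: 0, offset: 128, len: 128 } (KV head 1). With world_size = 2 the
// early return instead yields a plain Shard::Simple split across the KV heads.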

/// Compute the number of KV groups, taking into account KV head replication.
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}
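
// Companion example for `compute_n_kv_groups` (same assumed numbers as above): with
// num_attention_heads = 32, total_num_kv_heads = 2, and world_size = 8, kv_replicate = 4
// and the function returns (32 / 2) / 4 = 4 query heads per KV head on each rank. With
// world_size <= 2, kv_replicate = 1 and the result is the usual 32 / 2 = 16.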