mistralrs_quant/distributed/layers.rs

use std::sync::Arc;

use candle_core::{Context, Device, Result, Tensor};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b,
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType,
    Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

/// This layer has a weight that is parallelized along the input dimension,
/// returning the "full" output dimension.
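///
/// As a rough sketch: with `in_dim = 4096`, `out_dim = 1024`, and a world size of 2, each
/// rank loads a `(1024, 2048)` shard of the `(1024, 4096)` weight (sharded on dim 1),
/// computes a partial matmul over its half of the input features, and `forward` sums the
/// partial results across ranks with an all-reduce before adding the (unsharded) bias.
///
/// ```ignore
/// // Hypothetical usage; assumes `comm` and `vb` are already set up for this rank.
/// let o_proj = RowParallelLayer::new(4096, 1024, &None, false, &comm, vb.pp("o_proj"))?;
/// let out = o_proj.forward(&hidden_states)?; // partial matmul + sum all-reduce
/// ```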
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

#[derive(Debug)]
/// This layer has a weight that is parallelized along the output dimension,
/// taking the "full" input dimension.
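///
/// As a rough sketch: with `in_dim = 4096`, `out_dim = 1024`, and a world size of 2, each
/// rank loads a `(512, 4096)` weight shard (and a `512`-long bias shard, if present) along
/// dim 0 and produces its own `512`-wide slice of the output. No collective is performed
/// here; callers typically pair this with a `RowParallelLayer`, which does the all-reduce.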
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

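    /// Load `chunks` column-parallel layers from one fused weight (e.g. a merged gate/up
    /// projection). Dim 0 of the stored tensor is treated as `chunks * world_size` shards;
    /// chunk `i` on this rank loads shard index `i * world_size + rank`, so each returned
    /// layer holds `out_dim / (chunks * world_size)` rows.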
    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

#[derive(Debug)]
/// This layer has no parallelization.
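/// The full weight is loaded identically on every rank and `forward` simply delegates to
/// the wrapped [`QuantMethod`], so no collective communication is required.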
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    /// Note: we only support AFQ and unquantized here because they are the only ones that support indexed matmul.
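    ///
    /// A minimal usage sketch (hypothetical call site; assumes `vb` points at a MoE block
    /// stored in the fused `gate_up_proj`/`down_proj` layout and `comm` has already been
    /// created for this rank):
    ///
    /// ```ignore
    /// let experts = PackedExperts::new(n_experts, hidden, inter, &None, false, &comm, vb)?;
    /// // experts.gate_proj[i] / up_proj[i] / down_proj[i] are expert i's projections.
    /// ```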
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Pre-quantized packed experts do not support tensor parallelism
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support distributed inference (world size {}). Use ISQ instead.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            // Parallelization scheme:
            // Each GPU holds all experts.
            // Gate/Up proj are parallelized on dim 2 (column-parallel).
            // Down proj is parallelized on dim 1 (row-parallel).
            // All-reduce at the end.

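            // For example (illustrative numbers): with intermediate_size = 8192 and
            // world_size = 2, rank 1 takes gate columns [4096, 8192) and up columns
            // [12288, 16384) of the fused gate_up_proj (whose last dim is 2 * 8192),
            // and rows [4096, 8192) along dim 1 of down_proj.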
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            // Gate is right before Up in the fused gate_up_proj tensor.
            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

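/// Experts fused into stacked per-projection weights of shape `(num_experts, ...)`.
/// The constructor currently requires a Metal device; see [`FusedExperts::new`].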
pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if !vb.device().is_metal() {
            candle_core::bail!("FastMoeMlp requires Metal.");
        }

        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
                let quantization_config = quantization_config.as_ref().unwrap();

                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.gate_proj"),
                )?;
                let fused_up_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.up_proj"),
                )?;
                let fused_down_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    moe_intermediate_size,
                    hidden_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.down_proj"),
                )?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);

                    let gate_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("gate_proj.weight"),
                    )?;
                    let up_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("up_proj.weight"),
                    )?;
                    let down_proj = crate::linear_no_bias(
                        moe_intermediate_size,
                        hidden_size,
                        quantization_config,
                        vb.pp("down_proj.weight"),
                    )?;

                    gate_proj_vec.push(gate_proj.dequantize_w()?);
                    up_proj_vec.push(up_proj.dequantize_w()?);
                    down_proj_vec.push(down_proj.dequantize_w()?);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            } else {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);
                    let gate_proj =
                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                    let down_proj =
                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                    gate_proj_vec.push(gate_proj);
                    up_proj_vec.push(up_proj);
                    down_proj_vec.push(down_proj);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
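///
/// For example: with 8 total KV heads and a world size of 4, each rank simply takes
/// `8 / 4 = 2` heads along dim 0. With 2 total KV heads and a world size of 8, heads are
/// replicated 4x: ranks 0-3 load the slice for head 0 and ranks 4-7 the slice for head 1,
/// each an offset/length of `head_dim` rows.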
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    // Tensor parallelism case

    // We may need to replicate the kv heads
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

/// Compute the number of KV groups, taking into account KV head replication.
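///
/// For example: with 32 attention heads and 2 KV heads at a world size of 8, the KV heads
/// are replicated `8 / 2 = 4` times, so the group count is `(32 / 2) / 4 = 4`, matching the
/// 4 attention heads and single (replicated) KV head each rank actually holds.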
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}