mistralrs_quant/distributed/layers.rs

use std::sync::Arc;

use candle_core::{Context, Result, Tensor};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b, distributed, gptq::gptq_linear,
    lora::merge_lora_weights, AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear,
    GgufMatMul, HqqLayer, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

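/// Build a [`Shard::Simple`] that selects this rank's slice out of `world_size`
/// equal parts along dimension `dim`.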
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

/// This layer has a weight that is parallelized along the input dimension,
/// returning the "full" output dimension.
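///
/// A minimal usage sketch; the `comm`/`vb` setup and the `down_proj` name are
/// illustrative assumptions, not requirements:
///
/// ```ignore
/// // Each rank loads only its slice of the weight (sharded along the input dim).
/// let down_proj = RowParallelLayer::new(
///     intermediate_size, // in_dim: split across ranks
///     hidden_size,       // out_dim: kept whole on every rank
///     &None,             // no quantization config
///     /* bias */ false,
///     &comm,
///     vb.pp("down_proj"),
/// )?;
/// // `forward` multiplies by the local shard and then sum-all-reduces, so every
/// // rank ends up with the full `(.., hidden_size)` output.
/// let y = down_proj.forward(&sharded_activations)?;
/// ```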
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let weight = if let Some(quant_conf) = &config {
            // GPTQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::Gptq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Gptq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        }))
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

#[derive(Debug)]
/// This layer has a weight that is parallelized along the output dimension,
/// taking the "full" input dimension.
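///
/// A minimal sketch of the usual pairing with [`RowParallelLayer`] in a
/// tensor-parallel MLP; the projection names, sizes, and `silu` activation are
/// illustrative assumptions:
///
/// ```ignore
/// // gate/up shard the intermediate dim, so their outputs feed the row-parallel
/// // down projection directly, with a single all-reduce at the very end.
/// let gate = ColumnParallelLayer::new(hidden, inter, &cfg, false, &comm, vb.pp("gate_proj"))?;
/// let up = ColumnParallelLayer::new(hidden, inter, &cfg, false, &comm, vb.pp("up_proj"))?;
/// let down = RowParallelLayer::new(inter, hidden, &cfg, false, &comm, vb.pp("down_proj"))?;
///
/// let h = (candle_nn::ops::silu(&gate.forward(&x)?)? * up.forward(&x)?)?;
/// let y = down.forward(&h)?; // sum-all-reduce happens inside RowParallelLayer
/// ```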
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = if let Some(quant_conf) = &config {
            // GPTQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::Gptq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Gptq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        Ok(Arc::new(Self { weight, bias }))
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

#[derive(Debug)]
/// This layer has no parallelization
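///
/// A minimal usage sketch (tensor name and sizes are illustrative):
///
/// ```ignore
/// // Every rank loads the full (out_dim, in_dim) weight; no communication is needed.
/// let lm_head = ReplicatedLayer::new(hidden_size, vocab_size, &None, false, vb.pp("lm_head"))?;
/// let logits = lm_head.forward(&xs)?;
/// ```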
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(lin),
        )?))
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::Gptq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb,
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb)?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        Ok(Arc::new(Self(layer)))
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    /// Note: we only support AFQ and unquantized here because they are the only ones that support indexed matmul.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Pre-quantized PackedExperts do not support tensor parallelism
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support tensor parallelism (world size {}). Use ISQ instead.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if vb.contains_tensor("gate_up_proj")
                        || vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support the `gate_up_proj` format.");
                    }

                    (
                        vec![AfqLayer::afq_packed_linear_b(
                            num_local_experts,
                            hidden_size,
                            intermediate_size,
                            quant_conf,
                            bias,
                            vb.pp("gate_proj"),
                        )?],
                        vec![AfqLayer::afq_packed_linear_b(
                            num_local_experts,
                            hidden_size,
                            intermediate_size,
                            quant_conf,
                            bias,
                            vb.pp("up_proj"),
                        )?],
                        vec![AfqLayer::afq_packed_linear_b(
                            num_local_experts,
                            intermediate_size,
                            hidden_size,
                            quant_conf,
                            bias,
                            vb.pp("down_proj"),
                        )?],
                    )
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("down_proj") {
            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for _ in 0..num_local_experts {
                let gate_proj = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                let up_proj = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                let down_proj = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                gs.push(Arc::new(gate_proj) as Arc<dyn QuantMethod>);
                us.push(Arc::new(up_proj) as Arc<dyn QuantMethod>);
                ds.push(Arc::new(down_proj) as Arc<dyn QuantMethod>);
            }
            (gs, us, ds)
        } else {
            // Parallelized like:
            // Each GPU holds all experts.
            // Gate/Up proj is parallelized on dim 2 (column)
            // Down proj is parallelized on dim 1 (row)
            // All reduce at the end.
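            //
            // A worked example with illustrative numbers: with intermediate_size = 4096
            // and world_size = 2, gate_up_block_size = 2048, so rank 0 reads gate
            // columns [0, 2048) and up columns [4096, 6144) of the packed gate_up_proj
            // tensor, while rank 1 reads [2048, 4096) and [6144, 8192).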

            // Compute this rank's slice of the packed projections.
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            // Gate comes right before Up in the packed gate_up_proj tensor.
            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in gate_proj
                .chunk(num_local_experts, 0)?
                .into_iter()
                .zip(up_proj.chunk(num_local_experts, 0)?)
                .zip(down_proj.chunk(num_local_experts, 0)?)
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let gate_proj = <UnquantLinear as QuantMethod>::new(
                    QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                )?;
                let up_proj = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(up_proj, None),
                ))?;
                let down_proj = <UnquantLinear as QuantMethod>::new(
                    QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                )?;
                gs.push(Arc::new(gate_proj) as Arc<dyn QuantMethod>);
                us.push(Arc::new(up_proj) as Arc<dyn QuantMethod>);
                ds.push(Arc::new(down_proj) as Arc<dyn QuantMethod>);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
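///
/// Worked examples with illustrative numbers: with `total_num_kv_heads = 8` and a
/// world size of 4, each rank simply takes a quarter of the KV projection along dim 0.
/// With `total_num_kv_heads = 2` and a world size of 8, the heads are replicated
/// (`kv_replicate = 4`): ranks 0-3 read the `head_dim`-long slice at offset
/// `0 * head_dim` and ranks 4-7 the slice at offset `1 * head_dim`.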
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    // Tensor parallelism case

    // We may need to replicate the kv heads
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

/// Compute the number of KV groups, taking into account KV head replication.
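///
/// Worked example with illustrative numbers: with `num_attention_heads = 32`,
/// `total_num_kv_heads = 2`, and a world size of 8, the KV heads are replicated 4x,
/// so each rank uses `(32 / 2) / 4 = 4` query heads per KV head.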
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}