mistralrs_quant/distributed/layers.rs

1use std::sync::Arc;
2
3use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
4use candle_nn::Linear;
5
6use crate::{
7    blockwise_fp8::blockwise_fp8_linear_b,
8    distributed,
9    gptq::gptq_linear,
10    lora::merge_lora_weights,
11    should_apply_immediate_isq,
12    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
13    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
14    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
15    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
16};
17
18use super::{Comm, SumAllReduce};
19
20fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
21    Shard::Simple {
22        dim,
23        rank,
24        world_size,
25    }
26}
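// Illustrative note (assuming `Shard::Simple` splits the given dimension into
// `world_size` equal parts and keeps the `rank`-th slice): for a weight of shape
// (out_dim, in_dim), `shard(1, rank, world_size)` keeps the columns
// [rank * in_dim / world_size, (rank + 1) * in_dim / world_size), while
// `shard(0, ...)` slices rows instead.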
27
28/// This layer has a weight that is parallelized along the input dimension,
29/// returning the "full" output dimension.
30#[derive(Debug)]
31pub struct RowParallelLayer {
32    weight: Arc<dyn QuantMethod>,
33    bias: Option<Tensor>,
34    all_reduce: distributed::SumAllReduce,
35}
36
37impl RowParallelLayer {
38    #[allow(clippy::new_ret_no_self)]
39    pub fn new(
40        in_dim: usize,
41        out_dim: usize,
42        config: &Option<QuantizedConfig>,
43        bias: bool,
44        comm: &Arc<crate::Comm>,
45        vb: ShardedVarBuilder,
46    ) -> Result<Arc<dyn QuantMethod>> {
47        let rank = comm.rank();
48        let world_size = comm.world_size();
49        let shard = shard(1, rank, world_size);
50
51        let base_vb = vb.clone();
52        let vb = if should_apply_immediate_isq(&vb) {
53            vb.set_device(Device::Cpu)
54        } else {
55            vb
56        };
57
58        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
60            if matches!(
61                quant_conf,
62                QuantizedConfig::GptqAwq { .. }
63                    | QuantizedConfig::Bitsandbytes { .. }
64                    | QuantizedConfig::Afq { .. }
65            ) && comm.world_size() != 1
66            {
67                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization do not support tensor parallelism, but got a world size of {}",
69                    comm.world_size()
70                );
71            }
72
73            match quant_conf {
74                QuantizedConfig::GptqAwq { .. } => {
75                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
76                }
77                QuantizedConfig::Fp8 { .. } => {
78                    // NOTE: no bias for fp8 as it might be parallelized
79                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
80                }
81                QuantizedConfig::Bitsandbytes { .. } => {
82                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
83                }
84                QuantizedConfig::Afq { .. } => {
85                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
86                }
87                QuantizedConfig::MXFP4 {} => {
88                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
89                }
90            }
91        } else {
92            // Handle the case where the layer is dummy (no tensors)
93            if !vb.contains_tensor("weight") {
94                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
95                Arc::new(layer) as Arc<dyn QuantMethod>
96            } else {
97                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
98                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
99
100                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
101                    Linear::new(weight, None),
102                ))?;
103                Arc::new(layer) as Arc<dyn QuantMethod>
104            }
105        };
106
107        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
108        let bias = if bias && vb.contains_tensor("bias") {
109            Some(vb.get((out_dim,), "bias")?)
110        } else {
111            None
112        };
113
114        let this_unquant = Arc::new(Self {
115            weight,
116            bias,
117            all_reduce: distributed::SumAllReduce::new(comm),
118        });
119        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
120        Ok(this)
121    }
122
123    #[allow(clippy::new_ret_no_self)]
124    pub fn new_matformer(
125        in_dim: usize,
126        out_dim: usize,
127        orig_intermediate_size: usize,
128        config: &Option<QuantizedConfig>,
129        bias: bool,
130        comm: &Arc<crate::Comm>,
131        vb: ShardedVarBuilder,
132    ) -> Result<Arc<dyn QuantMethod>> {
133        let rank = comm.rank();
134        let world_size = comm.world_size();
135        let shard = shard(1, rank, world_size);
136
137        let base_vb = vb.clone();
138        let vb = if should_apply_immediate_isq(&vb) {
139            vb.set_device(Device::Cpu)
140        } else {
141            vb
142        };
143
144        if config.is_some() {
145            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
146        }
147
148        // Handle the case where the layer is dummy (no tensors)
149        let weight = if !vb.contains_tensor("weight") {
150            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
151            Arc::new(layer) as Arc<dyn QuantMethod>
152        } else {
153            let weight = vb
154                .get_with_hints(
155                    (out_dim, orig_intermediate_size),
156                    "weight",
157                    Default::default(),
158                )?
159                .i((.., ..in_dim))?
160                .contiguous()?;
161
162            let weight = shard.apply_to(&weight)?;
163            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
164
165            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
166                Linear::new(weight, None),
167            ))?;
168            Arc::new(layer) as Arc<dyn QuantMethod>
169        };
170
171        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
172        let bias = if bias && vb.contains_tensor("bias") {
173            Some(vb.get((out_dim,), "bias")?)
174        } else {
175            None
176        };
177
178        let this_unquant = Arc::new(Self {
179            weight,
180            bias,
181            all_reduce: distributed::SumAllReduce::new(comm),
182        });
183        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
184        Ok(this)
185    }
186}
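
// Illustrative usage sketch (the function name and dimensions below are placeholders,
// not taken from the surrounding crate): building a row-parallel projection for an MLP
// `down_proj`, given a `Comm` and a `ShardedVarBuilder` produced by the model-loading code.
#[allow(dead_code)]
fn example_row_parallel(
    comm: &Arc<crate::Comm>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    // intermediate_size -> hidden_size, unquantized, without a bias
    RowParallelLayer::new(11008, 4096, &None, false, comm, vb)
}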
187
188impl QuantMethod for RowParallelLayer {
189    fn new(_method: QuantMethodConfig) -> Result<Self>
190    where
191        Self: Sized,
192    {
193        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
194    }
195
196    fn forward(&self, a: &Tensor) -> Result<Tensor> {
197        let mut xs = self.weight.forward(a)?;
198        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
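        // Each rank holds the full (unsharded) bias, so it is added only after the
        // all-reduce; adding it before the reduction would count it `world_size` times.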
199        if let Some(bias) = &self.bias {
200            xs = xs.broadcast_add(bias)?;
201        }
202        Ok(xs)
203    }
204
205    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
206        let weight = self.weight.add_delta_w(delta)?;
207        Ok(Arc::new(Self {
208            weight,
209            bias: self.bias.clone(),
210            all_reduce: self.all_reduce.clone(),
211        }))
212    }
213
214    fn dequantize_w(&self) -> Result<Tensor> {
215        self.weight.dequantize_w()
216    }
217
218    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
219        self.weight.dtype_and_device()
220    }
221
222    fn begin_track_stats(&mut self) -> Result<()> {
223        Arc::get_mut(&mut self.weight)
224            .context("Failed to get &mut to weight")?
225            .begin_track_stats()
226    }
227
228    fn end_track_stats(&self) -> Result<Tensor> {
229        self.weight.end_track_stats()
230    }
231
232    fn quantized_act_type(&self) -> Option<candle_core::DType> {
233        self.weight.quantized_act_type()
234    }
235
236    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
237        self.weight.unquant_weight_bias()
238    }
239
240    fn apply_isq(
241        self: Arc<Self>,
242        dtype: Option<crate::IsqType>,
243        device: candle_core::Device,
244        n_quantized: &std::sync::atomic::AtomicUsize,
245        imatrix_weight: Option<Vec<f32>>,
246        guard: QuantizeOntoGuard,
247    ) -> Result<Arc<dyn QuantMethod>> {
248        let weight =
249            self.weight
250                .clone()
251                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
252        let bias = match &self.bias {
253            Some(b) => {
254                let (dtype, device) = weight.dtype_and_device();
255                Some(b.to_device(&device)?.to_dtype(dtype)?)
256            }
257            None => None,
258        };
259        Ok(Arc::new(Self {
260            weight,
261            bias,
262            all_reduce: self.all_reduce.clone(),
263        }))
264    }
265
266    fn is_distributed(&self) -> Option<DistributedKind> {
267        Some(DistributedKind::RowParallel)
268    }
269}
270
271impl QuantizedSerde for RowParallelLayer {
272    fn isq_serde_supported(&self) -> bool {
273        self.weight.isq_serde_supported()
274    }
275    fn name(&self) -> &'static str {
276        self.weight.name()
277    }
278    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
279        self.weight.serialize_with_bias(self.bias.clone())
280    }
281    fn deserialize(
282        data: std::borrow::Cow<[u8]>,
283        device: &candle_core::Device,
284        comm: &Arc<crate::Comm>,
285        guard: QuantizeOntoGuard,
286    ) -> Result<Arc<dyn QuantMethod>>
287    where
288        Self: Sized,
289    {
290        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
291        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
292        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
293            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
294            QuantizedSerdeType::Unquant => {
295                UnquantLinear::deserialize_ext_bias(data, device, guard)?
296            }
297            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
298            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
299            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
300        };
301        Ok(Arc::new(Self {
302            weight,
303            bias,
304            all_reduce: SumAllReduce::new(comm),
305        }))
306    }
307}
308
309#[derive(Debug)]
310/// This layer has a weight that is parallelized along the output dimension,
311/// taking the "full" input dimension.
312pub struct ColumnParallelLayer {
313    weight: Arc<dyn QuantMethod>,
314    bias: Option<Tensor>,
315}
316
317impl ColumnParallelLayer {
318    #[allow(clippy::new_ret_no_self)]
319    pub fn new_with_shard(
320        in_dim: usize,
321        out_dim: usize,
322        config: &Option<QuantizedConfig>,
323        bias: bool,
324        comm: &Arc<crate::Comm>,
325        shard: Shard,
326        vb: ShardedVarBuilder,
327    ) -> Result<Arc<dyn QuantMethod>> {
328        let base_vb = vb.clone();
329        let vb = if should_apply_immediate_isq(&vb) {
330            vb.set_device(Device::Cpu)
331        } else {
332            vb
333        };
334
335        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
337            if matches!(
338                quant_conf,
339                QuantizedConfig::GptqAwq { .. }
340                    | QuantizedConfig::Bitsandbytes { .. }
341                    | QuantizedConfig::Afq { .. }
342            ) && comm.world_size() != 1
343            {
344                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization do not support tensor parallelism, but got a world size of {}",
346                    comm.world_size()
347                );
348            }
349
350            match quant_conf {
351                QuantizedConfig::GptqAwq { .. } => {
352                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
353                }
354                QuantizedConfig::Fp8 { .. } => {
355                    // NOTE: no bias for fp8 as it might be parallelized
356                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
357                }
358                QuantizedConfig::Bitsandbytes { .. } => {
359                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
360                }
361                QuantizedConfig::Afq { .. } => {
362                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
363                }
364                QuantizedConfig::MXFP4 {} => {
365                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
366                }
367            }
368        } else {
369            // Handle the case where the layer is dummy (no tensors)
370            if !vb.contains_tensor("weight") {
371                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
372                Arc::new(layer) as Arc<dyn QuantMethod>
373            } else {
374                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
375                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
376
377                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
378                    Linear::new(weight, None),
379                ))?;
380                Arc::new(layer) as Arc<dyn QuantMethod>
381            }
382        };
383
384        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
385        let bias = if bias && vb.contains_tensor("bias") {
386            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
387        } else {
388            None
389        };
390
391        let this_unquant = Arc::new(Self { weight, bias });
392        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
393        Ok(this)
394    }
395
396    #[allow(clippy::new_ret_no_self)]
397    pub fn new(
398        in_dim: usize,
399        out_dim: usize,
400        config: &Option<QuantizedConfig>,
401        bias: bool,
402        comm: &Arc<crate::Comm>,
403        vb: ShardedVarBuilder,
404    ) -> Result<Arc<dyn QuantMethod>> {
405        let rank = comm.rank();
406        let world_size = comm.world_size();
407        let shard = shard(0, rank, world_size);
408
409        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
410    }
411
412    #[allow(clippy::new_ret_no_self)]
413    pub fn new_matformer(
414        in_dim: usize,
415        out_dim: usize,
416        orig_intermediate_size: usize,
417        config: &Option<QuantizedConfig>,
418        bias: bool,
419        comm: &Arc<crate::Comm>,
420        vb: ShardedVarBuilder,
421    ) -> Result<Arc<dyn QuantMethod>> {
422        let rank = comm.rank();
423        let world_size = comm.world_size();
424        let shard = shard(0, rank, world_size);
425
426        let base_vb = vb.clone();
427        let vb = if should_apply_immediate_isq(&vb) {
428            vb.set_device(Device::Cpu)
429        } else {
430            vb
431        };
432
433        if config.is_some() {
434            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
435        }
436
437        // Handle the case where the layer is dummy (no tensors)
438        let weight = if !vb.contains_tensor("weight") {
439            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
440            Arc::new(layer) as Arc<dyn QuantMethod>
441        } else {
442            let weight = vb
443                .get_with_hints(
444                    (orig_intermediate_size, in_dim),
445                    "weight",
446                    Default::default(),
447                )?
448                .i((..out_dim, ..))?
449                .contiguous()?;
450
451            let weight = shard.apply_to(&weight)?;
452            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
453
454            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
455                Linear::new(weight, None),
456            ))?;
457            Arc::new(layer) as Arc<dyn QuantMethod>
458        };
459
460        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
461        let bias = if bias && vb.contains_tensor("bias") {
462            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
463        } else {
464            None
465        };
466
467        let this_unquant = Arc::new(Self { weight, bias });
468        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
469        Ok(this)
470    }
471
472    pub fn new_merged(
473        in_dim: usize,
474        out_dim: usize,
475        chunks: usize,
476        config: &Option<QuantizedConfig>,
477        bias: bool,
478        comm: &Arc<crate::Comm>,
479        vb: ShardedVarBuilder,
480    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
481        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
482        for chunk_idx in 0..chunks {
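            // Interpretation (sketch): dim 0 is cut into `chunks * world_size` slices, so
            // chunk `chunk_idx` collectively owns the `chunk_idx`-th block of
            // `out_dim / chunks` rows, and this rank loads slice index
            // `chunk_idx * world_size + rank`, i.e. its column-parallel share of that block.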
483            let layer = ColumnParallelLayer::new_with_shard(
484                in_dim,
485                out_dim,
486                config,
487                bias,
488                comm,
489                shard(
490                    0,
491                    chunk_idx * comm.world_size() + comm.rank(),
492                    chunks * comm.world_size(),
493                ),
494                vb.clone(),
495            )?;
496            vec_layers.push(layer);
497        }
498        Ok(vec_layers)
499    }
500}
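
// Illustrative usage sketch (the function name and dimensions below are placeholders): a
// tensor-parallel MLP pairs column-parallel gate/up projections with a row-parallel down
// projection, so the only cross-rank communication is the all-reduce inside
// `RowParallelLayer::forward`.
#[allow(dead_code)]
fn example_tp_mlp(
    cfg: &Option<QuantizedConfig>,
    comm: &Arc<crate::Comm>,
    vb: ShardedVarBuilder,
) -> Result<(Arc<dyn QuantMethod>, Arc<dyn QuantMethod>, Arc<dyn QuantMethod>)> {
    let gate = ColumnParallelLayer::new(4096, 11008, cfg, false, comm, vb.pp("gate_proj"))?;
    let up = ColumnParallelLayer::new(4096, 11008, cfg, false, comm, vb.pp("up_proj"))?;
    let down = RowParallelLayer::new(11008, 4096, cfg, false, comm, vb.pp("down_proj"))?;
    Ok((gate, up, down))
}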
501
502impl QuantMethod for ColumnParallelLayer {
503    fn new(_method: QuantMethodConfig) -> Result<Self>
504    where
505        Self: Sized,
506    {
507        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
508    }
509
510    fn forward(&self, a: &Tensor) -> Result<Tensor> {
511        let mut xs = self.weight.forward(a)?;
512        if let Some(bias) = &self.bias {
513            xs = xs.broadcast_add(bias)?;
514        }
515        Ok(xs)
516    }
517
518    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
519        let weight = self.weight.add_delta_w(delta)?;
520        Ok(Arc::new(Self {
521            weight,
522            bias: self.bias.clone(),
523        }))
524    }
525
526    fn dequantize_w(&self) -> Result<Tensor> {
527        self.weight.dequantize_w()
528    }
529
530    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
531        self.weight.dtype_and_device()
532    }
533
534    fn begin_track_stats(&mut self) -> Result<()> {
535        Arc::get_mut(&mut self.weight)
536            .context("Failed to get &mut to weight")?
537            .begin_track_stats()
538    }
539
540    fn end_track_stats(&self) -> Result<Tensor> {
541        self.weight.end_track_stats()
542    }
543
544    fn quantized_act_type(&self) -> Option<candle_core::DType> {
545        self.weight.quantized_act_type()
546    }
547
548    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
549        self.weight.unquant_weight_bias()
550    }
551
552    fn apply_isq(
553        self: Arc<Self>,
554        dtype: Option<crate::IsqType>,
555        device: candle_core::Device,
556        n_quantized: &std::sync::atomic::AtomicUsize,
557        imatrix_weight: Option<Vec<f32>>,
558        guard: QuantizeOntoGuard,
559    ) -> Result<Arc<dyn QuantMethod>> {
560        let weight =
561            self.weight
562                .clone()
563                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
564        let bias = match &self.bias {
565            Some(b) => {
566                let (dtype, device) = weight.dtype_and_device();
567                Some(b.to_device(&device)?.to_dtype(dtype)?)
568            }
569            None => None,
570        };
571        Ok(Arc::new(Self { weight, bias }))
572    }
573
574    fn is_distributed(&self) -> Option<DistributedKind> {
575        Some(DistributedKind::ColumnParallel)
576    }
577}
578
579impl QuantizedSerde for ColumnParallelLayer {
580    fn isq_serde_supported(&self) -> bool {
581        self.weight.isq_serde_supported()
582    }
583    fn name(&self) -> &'static str {
584        self.weight.name()
585    }
586    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
587        self.weight.serialize_with_bias(self.bias.clone())
588    }
589    fn deserialize(
590        data: std::borrow::Cow<[u8]>,
591        device: &candle_core::Device,
592        _comm: &Arc<crate::Comm>,
593        guard: QuantizeOntoGuard,
594    ) -> Result<Arc<dyn QuantMethod>>
595    where
596        Self: Sized,
597    {
598        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
599        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
600        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
601            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
602            QuantizedSerdeType::Unquant => {
603                UnquantLinear::deserialize_ext_bias(data, device, guard)?
604            }
605            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
606            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
607            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
608        };
609        Ok(Arc::new(Self { weight, bias }))
610    }
611}
612
613#[derive(Debug)]
614/// This layer has no parallelization
615pub struct ReplicatedLayer(Arc<dyn QuantMethod>);
616
617impl ReplicatedLayer {
618    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
619        let dev = lin.weight().device().clone();
620        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
621        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
622        Ok(this)
623    }
624
625    #[allow(clippy::new_ret_no_self)]
626    pub fn new(
627        in_dim: usize,
628        out_dim: usize,
629        config: &Option<QuantizedConfig>,
630        bias: bool,
631        vb: ShardedVarBuilder,
632    ) -> Result<Arc<dyn QuantMethod>> {
633        let base_vb = vb.clone();
634        let vb = if should_apply_immediate_isq(&vb) {
635            vb.set_device(Device::Cpu)
636        } else {
637            vb
638        };
639
640        let layer = if let Some(quant_conf) = &config {
641            match quant_conf {
642                QuantizedConfig::GptqAwq { .. } => {
643                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
644                }
645                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
646                    in_dim,
647                    out_dim,
648                    quant_conf,
649                    bias,
650                    Default::default(),
651                    vb.clone(),
652                )?,
653                QuantizedConfig::Bitsandbytes { .. } => {
654                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
655                }
656                QuantizedConfig::Afq { .. } => {
657                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
658                }
659                QuantizedConfig::MXFP4 {} => {
660                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
661                }
662            }
663        } else {
664            // Handle the case where the layer is dummy (no tensors)
665            if !vb.contains_tensor("weight") {
666                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
667                Arc::new(layer) as Arc<dyn QuantMethod>
668            } else {
669                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
670                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
671
672                let bias = if bias {
673                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
674                } else {
675                    None
676                };
677                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
678                    Linear::new(weight, bias),
679                ))?;
680                Arc::new(layer) as Arc<dyn QuantMethod>
681            }
682        };
683
684        let this_unquant = Arc::new(Self(layer));
685        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
686        Ok(this)
687    }
688
689    #[allow(clippy::new_ret_no_self)]
690    pub fn new_layers_matformer_indices(
691        in_dim: usize,
692        out_dim: usize,
693        kept_layers_indices: Option<&Tensor>,
694        orig_num_hidden_layers: usize,
695        config: &Option<QuantizedConfig>,
696        bias: bool,
697        vb: ShardedVarBuilder,
698    ) -> Result<Arc<dyn QuantMethod>> {
699        let base_vb = vb.clone();
700        let vb = if should_apply_immediate_isq(&vb) {
701            vb.set_device(Device::Cpu)
702        } else {
703            vb
704        };
705
706        let layer = if let Some(quant_conf) = &config {
707            if kept_layers_indices.is_some() {
708                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
709            }
710
711            match quant_conf {
712                QuantizedConfig::GptqAwq { .. } => {
713                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
714                }
715                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
716                    in_dim,
717                    out_dim,
718                    quant_conf,
719                    bias,
720                    Default::default(),
721                    vb.clone(),
722                )?,
723                QuantizedConfig::Bitsandbytes { .. } => {
724                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
725                }
726                QuantizedConfig::Afq { .. } => {
727                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
728                }
729                QuantizedConfig::MXFP4 {} => {
730                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
731                }
732            }
733        } else {
734            // Handle the case where the layer is dummy (no tensors)
735            if !vb.contains_tensor("weight") {
736                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
737                Arc::new(layer) as Arc<dyn QuantMethod>
738            } else {
739                let mut weight =
740                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
741
742                if let Some(kept_layers_indices) = &kept_layers_indices {
743                    let weight_reshaped = weight.reshape((
744                        orig_num_hidden_layers,
745                        weight.dim(0)? / orig_num_hidden_layers,
746                        weight.dim(1)?,
747                    ))?;
748
749                    weight = weight_reshaped
750                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
751                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
752                        .contiguous()?;
753                }
754
755                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
756
757                let bias = if bias {
758                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
759                } else {
760                    None
761                };
762                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
763                    Linear::new(weight, bias),
764                ))?;
765                Arc::new(layer) as Arc<dyn QuantMethod>
766            }
767        };
768
769        let this_unquant = Arc::new(Self(layer));
770        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
771        Ok(this)
772    }
773}
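
// Illustrative usage sketch (the shape and dtype are placeholders): wrapping a plain candle
// `Linear` as a replicated (non-sharded) layer via `ReplicatedLayer::from_linear`.
#[allow(dead_code)]
fn example_replicated_from_linear() -> Result<Arc<dyn QuantMethod>> {
    let weight = Tensor::zeros((4096, 4096), candle_core::DType::F32, &Device::Cpu)?;
    ReplicatedLayer::from_linear(Linear::new(weight, None))
}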
774
775impl QuantMethod for ReplicatedLayer {
776    fn new(_method: QuantMethodConfig) -> Result<Self>
777    where
778        Self: Sized,
779    {
780        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
781    }
782
783    fn forward(&self, a: &Tensor) -> Result<Tensor> {
784        self.0.forward(a)
785    }
786
787    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
788        self.0.add_delta_w(delta)
789    }
790
791    fn dequantize_w(&self) -> Result<Tensor> {
792        self.0.dequantize_w()
793    }
794
795    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
796        self.0.dtype_and_device()
797    }
798
799    fn begin_track_stats(&mut self) -> Result<()> {
800        Arc::get_mut(&mut self.0)
801            .context("Failed to get &mut to weight")?
802            .begin_track_stats()
803    }
804
805    fn end_track_stats(&self) -> Result<Tensor> {
806        self.0.end_track_stats()
807    }
808
809    fn quantized_act_type(&self) -> Option<candle_core::DType> {
810        self.0.quantized_act_type()
811    }
812
813    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
814        self.0.unquant_weight_bias()
815    }
816
817    fn apply_isq(
818        self: Arc<Self>,
819        dtype: Option<crate::IsqType>,
820        device: candle_core::Device,
821        n_quantized: &std::sync::atomic::AtomicUsize,
822        imatrix_weight: Option<Vec<f32>>,
823        guard: QuantizeOntoGuard,
824    ) -> Result<Arc<dyn QuantMethod>> {
825        self.0
826            .clone()
827            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
828    }
829
830    fn is_distributed(&self) -> Option<DistributedKind> {
831        Some(DistributedKind::Replicated)
832    }
833}
834
835impl QuantizedSerde for ReplicatedLayer {
836    fn isq_serde_supported(&self) -> bool {
837        self.0.isq_serde_supported()
838    }
839    fn name(&self) -> &'static str {
840        self.0.name()
841    }
842    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
843        self.0.serialize()
844    }
845    fn deserialize(
846        data: std::borrow::Cow<[u8]>,
847        device: &candle_core::Device,
848        comm: &Arc<crate::Comm>,
849        guard: QuantizeOntoGuard,
850    ) -> Result<Arc<dyn QuantMethod>>
851    where
852        Self: Sized,
853    {
854        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
855        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
856        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
857            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
858            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
859            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
860            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
861            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
862        };
863        Ok(Arc::new(Self(deserialized)))
864    }
865}
866
867#[derive(Debug)]
868pub struct PackedExperts {
869    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
870    pub up_proj: Vec<Arc<dyn QuantMethod>>,
871    pub down_proj: Vec<Arc<dyn QuantMethod>>,
872}
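// Note (derived from `PackedExperts::new` below): with an AFQ config each of the three
// vectors holds a single packed layer covering every expert, while the unquantized path
// holds one layer per expert (`num_local_experts` entries).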
873
874impl PackedExperts {
    /// Note: only AFQ and unquantized layers are supported here, as they are the only ones that support indexed (per-expert) matmul.
876    #[allow(clippy::too_many_arguments)]
877    pub fn new(
878        num_local_experts: usize,
879        hidden_size: usize,
880        intermediate_size: usize,
881        config: &Option<QuantizedConfig>,
882        bias: bool,
883        comm: &Arc<crate::Comm>,
884        vb: ShardedVarBuilder,
885    ) -> Result<Self> {
886        if bias {
887            candle_core::bail!("PackedExperts does not support bias.");
888        }
889
890        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Pre-quantized PackedExperts do not support tensor parallelism
892            if comm.world_size() != 1 {
893                candle_core::bail!(
894                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
895                    comm.world_size()
896                );
897            }
898
899            match quant_conf {
900                QuantizedConfig::Afq { .. } => {
901                    if !vb.contains_tensor("gate_up_proj")
902                        || !vb.contains_tensor("gate_up_proj.weight")
903                    {
904                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
905                    }
906
907                    let base_vb = vb.clone();
908
909                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
910                        vb.pp("gate_proj").set_device(Device::Cpu)
911                    } else {
912                        vb.pp("gate_proj")
913                    };
914                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
915                        vb.pp("up_proj").set_device(Device::Cpu)
916                    } else {
917                        vb.pp("up_proj")
918                    };
919                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
920                        vb.pp("down_proj").set_device(Device::Cpu)
921                    } else {
922                        vb.pp("down_proj")
923                    };
924                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
925                        num_local_experts,
926                        hidden_size,
927                        intermediate_size,
928                        quant_conf,
929                        bias,
930                        vb_gate_proj,
931                    )?;
932                    let mut up_proj = AfqLayer::afq_packed_linear_b(
933                        num_local_experts,
934                        hidden_size,
935                        intermediate_size,
936                        quant_conf,
937                        bias,
938                        vb_up_proj,
939                    )?;
940                    let mut down_proj = AfqLayer::afq_packed_linear_b(
941                        num_local_experts,
942                        intermediate_size,
943                        hidden_size,
944                        quant_conf,
945                        bias,
946                        vb_down_proj,
947                    )?;
948
949                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
950                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
951                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;
952
953                    (vec![gate_proj], vec![up_proj], vec![down_proj])
954                }
955                _ => candle_core::bail!(
956                    "PackedExperts with quantization config only allows AFQ quantization"
957                ),
958            }
959        } else if !vb.contains_tensor("gate_up_proj") {
960            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
961            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
962            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
963            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
964            for _ in 0..num_local_experts {
965                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
966                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
967                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
968            }
969            (gs, us, ds)
970        } else {
            // Parallelized like:
            // Each GPU holds all experts.
            // Gate/up proj is parallelized on dim 2 (columns).
            // Down proj is parallelized on dim 1 (rows).
            // All-reduce at the end.

978            let gate_up_block_size = intermediate_size / comm.world_size();
979            let gate_up_start = gate_up_block_size * comm.rank();
980
981            // Gate is right before Up in the gate_up
982            let shard_gate = Shard::Offset {
983                dim: 2,
984                offset: gate_up_start,
985                len: gate_up_block_size,
986            };
987            let shard_up = Shard::Offset {
988                dim: 2,
989                offset: intermediate_size + gate_up_start,
990                len: gate_up_block_size,
991            };
992            let shard_down = Shard::Simple {
993                dim: 1,
994                rank: comm.rank(),
995                world_size: comm.world_size(),
996            };
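            // Worked example (illustrative numbers): with intermediate_size = 8192 and
            // world_size = 4, gate_up_block_size = 2048; rank 1 then reads columns
            // 2048..4096 of the fused gate_up tensor for gate and columns
            // 10240..12288 (= 8192 + 2048 .. 8192 + 4096) for up, since gate occupies the
            // first `intermediate_size` columns and up the second.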
997
998            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
999                vb.pp("gate_up_proj").set_device(Device::Cpu)
1000            } else {
1001                vb.pp("gate_up_proj")
1002            };
1003            let vb_down_proj = if should_apply_immediate_isq(&vb) {
1004                vb.pp("down_proj").set_device(Device::Cpu)
1005            } else {
1006                vb.pp("down_proj")
1007            };
1008
1009            let gate_proj = vb
1010                .get_with_hints(
1011                    (num_local_experts, hidden_size, intermediate_size * 2),
1012                    "gate_up_proj",
1013                    shard_gate,
1014                )?
1015                .t()?
1016                .contiguous()?;
1017            let up_proj = vb
1018                .get_with_hints(
1019                    (num_local_experts, hidden_size, intermediate_size * 2),
1020                    "gate_up_proj",
1021                    shard_up,
1022                )?
1023                .t()?
1024                .contiguous()?;
1025            let down_proj = vb
1026                .get_with_hints(
1027                    (num_local_experts, intermediate_size, hidden_size),
1028                    "down_proj",
1029                    shard_down,
1030                )?
1031                .t()?
1032                .contiguous()?;
1033
1034            let gc = gate_proj.chunk(num_local_experts, 0)?;
1035            let uc = up_proj.chunk(num_local_experts, 0)?;
1036            let dc = down_proj.chunk(num_local_experts, 0)?;
1037            drop((gate_proj, up_proj, down_proj));
1038
1039            let mut gs = Vec::new();
1040            let mut us = Vec::new();
1041            let mut ds = Vec::new();
1042            for ((mut gate_proj, mut up_proj), mut down_proj) in
1043                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
1044            {
1045                gate_proj = gate_proj.squeeze(0)?;
1046                up_proj = up_proj.squeeze(0)?;
1047                down_proj = down_proj.squeeze(0)?;
1048                let gate_proj = merge_lora_weights(
1049                    &vb,
1050                    gate_proj,
1051                    hidden_size,
1052                    intermediate_size * 2,
1053                    shard_gate,
1054                )?;
1055                let up_proj =
1056                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1057                let down_proj =
1058                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1059
1060                let mut gate_proj: Arc<dyn QuantMethod> =
1061                    Arc::new(<UnquantLinear as QuantMethod>::new(
1062                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1063                    )?);
1064                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1065                let mut up_proj: Arc<dyn QuantMethod> =
1066                    Arc::new(<UnquantLinear as QuantMethod>::new(
1067                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1068                    )?);
1069                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1070                let mut down_proj: Arc<dyn QuantMethod> =
1071                    Arc::new(<UnquantLinear as QuantMethod>::new(
1072                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1073                    )?);
1074                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1075                gs.push(gate_proj);
1076                us.push(up_proj);
1077                ds.push(down_proj);
1078            }
1079            (gs, us, ds)
1080        };
1081
1082        Ok(Self {
1083            gate_proj,
1084            up_proj,
1085            down_proj,
1086        })
1087    }
1088}
1089
1090pub struct FusedExperts {
1091    pub fused_gate_proj: Arc<dyn QuantMethod>,
1092    pub fused_up_proj: Arc<dyn QuantMethod>,
1093    pub fused_down_proj: Arc<dyn QuantMethod>,
1094}
1095
1096impl FusedExperts {
1097    pub fn new(
1098        hidden_size: usize,
1099        moe_intermediate_size: usize,
1100        num_experts: usize,
1101        quantization_config: &Option<QuantizedConfig>,
1102        vb: ShardedVarBuilder,
1103    ) -> Result<Self> {
1104        if !vb.device().is_metal() {
1105            candle_core::bail!("FastMoeMlp requires Metal.");
1106        }
1107
1108        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
1109            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
1110                let quantization_config = quantization_config.as_ref().unwrap();
1111
1112                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1113                    num_experts,
1114                    hidden_size,
1115                    moe_intermediate_size,
1116                    quantization_config,
1117                    false,
1118                    vb.pp("switch_mlp.gate_proj"),
1119                )?;
1120                let fused_up_proj = AfqLayer::afq_packed_linear_b(
1121                    num_experts,
1122                    hidden_size,
1123                    moe_intermediate_size,
1124                    quantization_config,
1125                    false,
1126                    vb.pp("switch_mlp.up_proj"),
1127                )?;
1128                let fused_down_proj = AfqLayer::afq_packed_linear_b(
1129                    num_experts,
1130                    moe_intermediate_size,
1131                    hidden_size,
1132                    quantization_config,
1133                    false,
1134                    vb.pp("switch_mlp.down_proj"),
1135                )?;
1136
1137                (fused_gate_proj, fused_up_proj, fused_down_proj)
1138            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1139                let experts_vb = vb.pp("experts");
1140                let mut gate_proj_vec = Vec::new();
1141                let mut up_proj_vec = Vec::new();
1142                let mut down_proj_vec = Vec::new();
1143                for i in 0..num_experts {
1144                    let vb = experts_vb.pp(i);
1145
1146                    let gate_proj = crate::linear_no_bias(
1147                        hidden_size,
1148                        moe_intermediate_size,
1149                        quantization_config,
1150                        vb.pp("gate_proj.weight"),
1151                    )?;
1152                    let up_proj = crate::linear_no_bias(
1153                        hidden_size,
1154                        moe_intermediate_size,
1155                        quantization_config,
1156                        vb.pp("up_proj.weight"),
1157                    )?;
1158                    let down_proj = crate::linear_no_bias(
1159                        moe_intermediate_size,
1160                        hidden_size,
1161                        quantization_config,
1162                        vb.pp("down_proj.weight"),
1163                    )?;
1164
1165                    gate_proj_vec.push(gate_proj.dequantize_w()?);
1166                    up_proj_vec.push(up_proj.dequantize_w()?);
1167                    down_proj_vec.push(down_proj.dequantize_w()?);
1168                }
1169
1170                let mut gate_proj: Arc<dyn QuantMethod> =
1171                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1172                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1173                    ))?);
1174                let mut up_proj: Arc<dyn QuantMethod> =
1175                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1176                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
1177                    ))?);
1178                let mut down_proj: Arc<dyn QuantMethod> =
1179                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1180                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1181                    ))?);
1182                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
1183                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
1184                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;
1185
1186                (gate_proj, up_proj, down_proj)
1187            } else {
1188                let experts_vb = vb.pp("experts");
1189                let mut gate_proj_vec = Vec::new();
1190                let mut up_proj_vec = Vec::new();
1191                let mut down_proj_vec = Vec::new();
1192                for i in 0..num_experts {
1193                    let vb = experts_vb.pp(i);
1194                    let gate_proj =
1195                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1196                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1197                    let down_proj =
1198                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1199
1200                    gate_proj_vec.push(gate_proj);
1201                    up_proj_vec.push(up_proj);
1202                    down_proj_vec.push(down_proj);
1203                }
1204
1205                let mut gate_proj: Arc<dyn QuantMethod> =
1206                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1207                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1208                    ))?);
1209                let mut up_proj: Arc<dyn QuantMethod> =
1210                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1211                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
1212                    ))?);
1213                let mut down_proj: Arc<dyn QuantMethod> =
1214                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1215                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1216                    ))?);
1217                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
1218                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
1219                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;
1220
1221                (gate_proj, up_proj, down_proj)
1222            };
1223
1224        Ok(Self {
1225            fused_gate_proj,
1226            fused_up_proj,
1227            fused_down_proj,
1228        })
1229    }
1230}
1231
1232/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
1233pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
1234    if comm.world_size() == 1 {
1235        return Shard::default();
1236    }
1237
1238    // Tensor parallelism case
1239
1240    // We may need to replicate the kv heads
1241    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1242        comm.world_size() / total_num_kv_heads
1243    } else {
1244        return Shard::Simple {
1245            dim: 0,
1246            rank: comm.rank(),
1247            world_size: comm.world_size(),
1248        };
1249    };
1250
1251    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
1252    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
1253    Shard::Offset {
1254        dim: 0,
1255        offset: kv_shard_id * head_dim,
1256        len: head_dim,
1257    }
1258}
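
// Worked example (illustrative numbers): with total_num_kv_heads = 8 and world_size = 4,
// this returns a `Shard::Simple { dim: 0, .. }` and each rank keeps 2 KV heads. With
// world_size = 16, kv_replicate = 2 and num_kv_heads = 1, so ranks 0 and 1 both load the
// slice at offset 0, ranks 2 and 3 the slice at `head_dim`, and so on: every KV head is
// replicated across two ranks.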
1259
1260/// Compute the number of KV groups, taking into account KV head replication.
1261pub fn compute_n_kv_groups(
1262    total_num_kv_heads: usize,
1263    num_attention_heads: usize,
1264    comm: &Comm,
1265) -> usize {
1266    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1267        comm.world_size() / total_num_kv_heads
1268    } else {
1269        1
1270    };
1271    if kv_replicate != 0 {
1272        (num_attention_heads / total_num_kv_heads) / kv_replicate
1273    } else {
1274        num_attention_heads / total_num_kv_heads
1275    }
1276}
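
// Worked example (illustrative numbers): with total_num_kv_heads = 8 and
// num_attention_heads = 32, a world size of 4 gives kv_replicate = 1 and 32 / 8 = 4
// KV groups, while a world size of 16 gives kv_replicate = 2 and (32 / 8) / 2 = 2 groups,
// matching the replicated KV shards produced by `compute_kv_shard` above.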