mistralrs_quant/distributed/layers.rs

use std::sync::Arc;

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    pertensor_fp8::pertensor_fp8_linear_b,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

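/// Convenience constructor for a [`Shard::Simple`]: split dimension `dim` into
/// `world_size` equal parts and take the part belonging to `rank`.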
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

/// This layer has a weight that is parallelized along the input dimension,
/// returning the "full" output dimension.
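///
/// A minimal sketch of the forward pass, assuming a hypothetical two-rank setup and
/// input activations that are already sharded along the hidden dimension (e.g. by a
/// preceding `ColumnParallelLayer`); the names below are illustrative, not part of this API:
///
/// ```ignore
/// // Each rank holds an (out_dim, in_dim / world_size) slice of the weight, so the
/// // local matmul yields a partial sum over the input dimension ...
/// let partial = local_weight.forward(&local_activations)?; // (batch, out_dim)
/// // ... and the all-reduce sums the partials across ranks before the bias is added.
/// let full = all_reduce.sum_all_reduce(&partial.contiguous()?)?;
/// ```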
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        // Handle the case where the layer is dummy (no tensors)
        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (out_dim, orig_intermediate_size),
                    "weight",
                    Default::default(),
                )?
                .i((.., ..in_dim))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

#[derive(Debug)]
/// This layer has a weight that is parallelized along the output dimension,
/// taking the "full" input dimension.
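///
/// A minimal sketch of how the shards line up, assuming a hypothetical two-rank setup
/// (the names below are illustrative, not part of this API):
///
/// ```ignore
/// // Each rank holds an (out_dim / world_size, in_dim) slice of the weight, so the
/// // local matmul consumes the full input and yields the local slice of the output.
/// let local_out = local_weight.forward(&full_input)?; // (batch, out_dim / world_size)
/// // No all-reduce is needed here; typically a following RowParallelLayer consumes the
/// // sharded activations and performs the reduction.
/// ```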
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        // Handle the case where the layer is dummy (no tensors)
        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (orig_intermediate_size, in_dim),
                    "weight",
                    Default::default(),
                )?
                .i((..out_dim, ..))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

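    /// Load `chunks` column-parallel layers from a single weight whose output dimension is
    /// stored merged: the merged dimension is split into `chunks * world_size` equal slices,
    /// and chunk `c` on rank `r` receives slice `c * world_size + r`.
    ///
    /// A small worked example of the slice index, assuming hypothetical values
    /// `chunks = 2` and `comm.world_size() = 2`:
    ///
    /// ```ignore
    /// // chunk 0, rank 0 -> slice 0      chunk 0, rank 1 -> slice 1
    /// // chunk 1, rank 0 -> slice 2      chunk 1, rank 1 -> slice 3
    /// let slice_index = chunk_idx * comm.world_size() + comm.rank();
    /// ```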
    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

#[derive(Debug)]
/// This layer has no parallelization
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_layers_matformer_indices(
        in_dim: usize,
        out_dim: usize,
        kept_layers_indices: Option<&Tensor>,
        orig_num_hidden_layers: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            if kept_layers_indices.is_some() {
                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let mut weight =
                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;

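                // A small worked example of the slicing below (the numbers are illustrative,
                // not taken from any particular model): with orig_num_hidden_layers = 4 and a
                // weight of shape (4 * rows_per_layer, in_dim), the reshape yields
                // (4, rows_per_layer, in_dim); index_select with kept_layers_indices = [0, 2]
                // keeps the rows of layers 0 and 2, and the final reshape flattens back to
                // (2 * rows_per_layer, in_dim).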
                if let Some(kept_layers_indices) = &kept_layers_indices {
                    let weight_reshaped = weight.reshape((
                        orig_num_hidden_layers,
                        weight.dim(0)? / orig_num_hidden_layers,
                        weight.dim(1)?,
                    ))?;

                    weight = weight_reshaped
                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
                        .contiguous()?;
                }

                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    /// Note: only AFQ, FP8, MXFP4, and unquantized layers are supported here, because they are the only ones that support indexed (per-expert) execution.
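    ///
    /// A minimal usage sketch with hypothetical MoE dimensions (the values and the
    /// surrounding `vb`/`comm` handles are illustrative, not taken from any model):
    ///
    /// ```ignore
    /// let experts = PackedExperts::new(
    ///     8,      // num_local_experts
    ///     4096,   // hidden_size
    ///     14336,  // intermediate_size
    ///     &None,  // no pre-quantization config
    ///     false,  // bias is not supported
    ///     &comm,
    ///     vb,
    /// )?;
    /// // `gate_proj`/`up_proj`/`down_proj` then hold one layer per expert (or a single
    /// // packed layer for the AFQ and MXFP4 paths).
    /// ```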
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Pre-quantized PackedExperts do not support tensor parallelism
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    // FP8 quantization for PackedExperts
                    // Keep weights as FP8 using BlockwiseFP8Linear to leverage native FP8 GEMM
                    let Some(weight_block_size) = weight_block_size else {
                        candle_core::bail!("Blockwise FP8 for PackedExperts requires weight_block_size to be set.")
                    };
                    if weight_block_size.len() != 2 {
                        candle_core::bail!(
                            "Expected weight_block_size to have length 2, got {weight_block_size:?}"
                        );
                    }

                    // Check if we have stacked format (gate_up_proj) or per-expert format
                    // Note: vb already has the "experts" prefix from the caller (experts.rs)
                    let is_stacked_format = vb.contains_tensor("gate_up_proj");

                    if is_stacked_format {
                        // Stacked format: load FP8 tensors and split
                        let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");

                        if has_fp8_scales {
                            // Load gate_up_proj FP8 tensor and scale
                            let gate_up_fp8 = vb.get_with_hints_dtype(
                                (num_local_experts, hidden_size, intermediate_size * 2),
                                "gate_up_proj",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let gate_up_scale = vb.get_with_hints_dtype(
                                (
                                    num_local_experts,
                                    hidden_size.div_ceil(weight_block_size[0]),
                                    (intermediate_size * 2).div_ceil(weight_block_size[1]),
                                ),
                                "gate_up_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            // Load down_proj FP8 tensor and scale
                            let down_fp8 = vb.get_with_hints_dtype(
                                (num_local_experts, intermediate_size, hidden_size),
                                "down_proj",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let down_scale = vb.get_with_hints_dtype(
                                (
                                    num_local_experts,
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "down_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            // Split and create individual BlockwiseFP8Linear for each expert
                            let mut gs = Vec::new();
                            let mut us = Vec::new();
                            let mut ds = Vec::new();

                            for i in 0..num_local_experts {
                                // Extract this expert's weights
                                let gate_up_expert =
                                    gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
                                let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
                                let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
                                let down_scale_expert = down_scale.i(i)?.contiguous()?;

                                // Split gate_up into gate and up
                                let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
                                let up_expert = gate_up_expert.narrow(
                                    0,
                                    intermediate_size,
                                    intermediate_size,
                                )?;

                                // Split scales
                                let gate_scale_expert = gate_up_scale_expert.narrow(
                                    1,
                                    0,
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                )?;
                                let up_scale_expert = gate_up_scale_expert.narrow(
                                    1,
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                )?;

                                // Create BlockwiseFP8Linear for each projection
                                use crate::blockwise_fp8::BlockwiseFP8Linear;
                                use crate::QuantMethodConfig;

                                let gate_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: gate_expert,
                                        weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );
                                let up_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: up_expert,
                                        weight_scale_inv: up_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );
                                let down_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: down_expert,
                                        weight_scale_inv: down_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );

                                gs.push(gate_layer);
                                us.push(up_layer);
                                ds.push(down_layer);
                            }

                            (gs, us, ds)
                        } else {
                            candle_core::bail!(
                                "PackedExperts with FP8 requires weight_scale_inv tensors"
                            );
                        }
                    } else {
                        // Per-expert format: load each expert individually
                        let mut gs = Vec::new();
                        let mut us = Vec::new();
                        let mut ds = Vec::new();

                        for i in 0..num_local_experts {
                            let expert_vb = vb.pp(i);

                            // Load FP8 weights and scales for each projection
                            let gate_fp8 = expert_vb.get_with_hints_dtype(
                                (intermediate_size, hidden_size),
                                "gate_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let gate_scale = expert_vb.get_with_hints_dtype(
                                (
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "gate_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let up_fp8 = expert_vb.get_with_hints_dtype(
                                (intermediate_size, hidden_size),
                                "up_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let up_scale = expert_vb.get_with_hints_dtype(
                                (
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "up_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let down_fp8 = expert_vb.get_with_hints_dtype(
                                (hidden_size, intermediate_size),
                                "down_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let down_scale = expert_vb.get_with_hints_dtype(
                                (
                                    hidden_size.div_ceil(weight_block_size[0]),
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                ),
                                "down_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            // Create BlockwiseFP8Linear for each projection
                            use crate::blockwise_fp8::BlockwiseFP8Linear;
                            use crate::QuantMethodConfig;

                            let gate_layer: Arc<dyn QuantMethod> = Arc::new(
                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                    weight: gate_fp8,
                                    weight_scale_inv: gate_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                })?,
                            );
                            let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
                                QuantMethodConfig::BlockwiseFP8 {
                                    weight: up_fp8,
                                    weight_scale_inv: up_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                },
                            )?);
                            let down_layer: Arc<dyn QuantMethod> = Arc::new(
                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                    weight: down_fp8,
                                    weight_scale_inv: down_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                })?,
                            );

                            gs.push(gate_layer);
                            us.push(up_layer);
                            ds.push(down_layer);
                        }

                        (gs, us, ds)
                    }
                }
                QuantizedConfig::MXFP4 {} => {
                    // MXFP4 quantization for PackedExperts
                    // Keep weights as MXFP4 using MXFP4Layer to leverage native MXFP4 GEMM
                    // Note: MXFP4 models use stacked format, so we load directly as packed experts
                    let gate_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb.pp("gate_proj"),
                    )?;
                    let up_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb.pp("up_proj"),
                    )?;
                    let down_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb.pp("down_proj"),
                    )?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ, FP8, or MXFP4 quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            // Parallelized like:
            // Each GPU holds all experts.
            // Gate/Up proj is parallelized on dim 2 (column).
            // Down proj is parallelized on dim 1 (row).
            // All reduce at the end.
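            //
            // A small worked example of the offsets computed below (illustrative numbers):
            // with intermediate_size = 8 and world_size = 2, the packed gate_up dimension
            // has size 16; rank 0 reads gate [0..4) and up [8..12), while rank 1 reads
            // gate [4..8) and up [12..16).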
1301
1302            // Handle the case where the layer is dummy (no tensors)
1303            let gate_up_block_size = intermediate_size / comm.world_size();
1304            let gate_up_start = gate_up_block_size * comm.rank();
1305
1306            // Gate is right before Up in the gate_up
1307            let shard_gate = Shard::Offset {
1308                dim: 2,
1309                offset: gate_up_start,
1310                len: gate_up_block_size,
1311            };
1312            let shard_up = Shard::Offset {
1313                dim: 2,
1314                offset: intermediate_size + gate_up_start,
1315                len: gate_up_block_size,
1316            };
1317            let shard_down = Shard::Simple {
1318                dim: 1,
1319                rank: comm.rank(),
1320                world_size: comm.world_size(),
1321            };
1322
1323            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
1324                vb.pp("gate_up_proj").set_device(Device::Cpu)
1325            } else {
1326                vb.pp("gate_up_proj")
1327            };
1328            let vb_down_proj = if should_apply_immediate_isq(&vb) {
1329                vb.pp("down_proj").set_device(Device::Cpu)
1330            } else {
1331                vb.pp("down_proj")
1332            };
1333
1334            let gate_proj = vb
1335                .get_with_hints(
1336                    (num_local_experts, hidden_size, intermediate_size * 2),
1337                    "gate_up_proj",
1338                    shard_gate,
1339                )?
1340                .t()?
1341                .contiguous()?;
1342            let up_proj = vb
1343                .get_with_hints(
1344                    (num_local_experts, hidden_size, intermediate_size * 2),
1345                    "gate_up_proj",
1346                    shard_up,
1347                )?
1348                .t()?
1349                .contiguous()?;
1350            let down_proj = vb
1351                .get_with_hints(
1352                    (num_local_experts, intermediate_size, hidden_size),
1353                    "down_proj",
1354                    shard_down,
1355                )?
1356                .t()?
1357                .contiguous()?;
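            // Shape note (derived from the shards above): `.t()` swaps the last two dims, so
            // gate_proj and up_proj are each [num_local_experts, gate_up_block_size, hidden_size]
            // and down_proj is [num_local_experts, hidden_size, intermediate_size / world_size].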
1358
1359            let gc = gate_proj.chunk(num_local_experts, 0)?;
1360            let uc = up_proj.chunk(num_local_experts, 0)?;
1361            let dc = down_proj.chunk(num_local_experts, 0)?;
1362            drop((gate_proj, up_proj, down_proj));
1363
1364            let mut gs = Vec::new();
1365            let mut us = Vec::new();
1366            let mut ds = Vec::new();
1367            for ((mut gate_proj, mut up_proj), mut down_proj) in
1368                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
1369            {
1370                gate_proj = gate_proj.squeeze(0)?;
1371                up_proj = up_proj.squeeze(0)?;
1372                down_proj = down_proj.squeeze(0)?;
1373                let gate_proj = merge_lora_weights(
1374                    &vb,
1375                    gate_proj,
1376                    hidden_size,
1377                    intermediate_size * 2,
1378                    shard_gate,
1379                )?;
1380                let up_proj =
1381                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1382                let down_proj =
1383                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1384
1385                let mut gate_proj: Arc<dyn QuantMethod> =
1386                    Arc::new(<UnquantLinear as QuantMethod>::new(
1387                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1388                    )?);
1389                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1390                let mut up_proj: Arc<dyn QuantMethod> =
1391                    Arc::new(<UnquantLinear as QuantMethod>::new(
1392                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1393                    )?);
1394                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1395                let mut down_proj: Arc<dyn QuantMethod> =
1396                    Arc::new(<UnquantLinear as QuantMethod>::new(
1397                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1398                    )?);
1399                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1400                gs.push(gate_proj);
1401                us.push(up_proj);
1402                ds.push(down_proj);
1403            }
1404            (gs, us, ds)
1405        };
1406
1407        Ok(Self {
1408            gate_proj,
1409            up_proj,
1410            down_proj,
1411        })
1412    }
1413}
1414
1415pub struct FusedExperts {
1416    pub fused_gate_proj: Arc<dyn QuantMethod>,
1417    pub fused_up_proj: Arc<dyn QuantMethod>,
1418    pub fused_down_proj: Arc<dyn QuantMethod>,
1419}
1420
1421impl FusedExperts {
1422    pub fn new(
1423        hidden_size: usize,
1424        moe_intermediate_size: usize,
1425        num_experts: usize,
1426        quantization_config: &Option<QuantizedConfig>,
1427        vb: ShardedVarBuilder,
1428    ) -> Result<Self> {
1429        // Detect if weights are in stacked format (e.g., Qwen3 VL MoE):
1430        // - experts.gate_up_proj: (num_experts, hidden_size, intermediate_size * 2)
1431        // - experts.down_proj: (num_experts, intermediate_size, hidden_size)
1432        // Or per-expert format (e.g., Qwen3 MoE):
1433        // - experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight, experts.{i}.down_proj.weight
1434        let experts_vb = vb.pp("experts");
1435        let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");
1436
1437        let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
1438            &quantization_config,
1439            Some(QuantizedConfig::Afq { .. })
1440        ) {
1441            let quantization_config = quantization_config.as_ref().unwrap();
1442
1443            let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1444                num_experts,
1445                hidden_size,
1446                moe_intermediate_size,
1447                quantization_config,
1448                false,
1449                vb.pp("switch_mlp.gate_proj"),
1450            )?;
1451            let fused_up_proj = AfqLayer::afq_packed_linear_b(
1452                num_experts,
1453                hidden_size,
1454                moe_intermediate_size,
1455                quantization_config,
1456                false,
1457                vb.pp("switch_mlp.up_proj"),
1458            )?;
1459            let fused_down_proj = AfqLayer::afq_packed_linear_b(
1460                num_experts,
1461                moe_intermediate_size,
1462                hidden_size,
1463                quantization_config,
1464                false,
1465                vb.pp("switch_mlp.down_proj"),
1466            )?;
1467
1468            (fused_gate_proj, fused_up_proj, fused_down_proj)
1469        } else if is_stacked_format
1470            && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
1471        {
1472            // Stacked format with FP8 quantization
1473            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1474            let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");
1475
1476            if has_fp8_scales {
1477                let weight_block_size = match quantization_config {
1478                    Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1479                    _ => unreachable!(),
1480                };
1481
1482                let Some(weight_block_size) = weight_block_size else {
1483                    candle_core::bail!(
1484                        "Blockwise FP8 for stacked experts requires weight_block_size to be set."
1485                    )
1486                };
1487                if weight_block_size.len() != 2 {
1488                    candle_core::bail!(
1489                        "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1490                    );
1491                }
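                // Worked example (hypothetical sizes): with weight_block_size = [128, 128],
                // num_experts = 128, hidden_size = 2048, and moe_intermediate_size = 768, the
                // loads below expect gate_up_proj as [128, 2048, 1536] in F8E4M3 and
                // gate_up_proj.weight_scale_inv as [128, 2048/128, 1536/128] = [128, 16, 12].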
1492
1493                // Load gate_up_proj FP8 tensor and scale
1494                // Shape: [num_experts, hidden_size, intermediate_size * 2]
1495                let gate_up_fp8 = experts_vb.get_with_hints_dtype(
1496                    (num_experts, hidden_size, moe_intermediate_size * 2),
1497                    "gate_up_proj",
1498                    Default::default(),
1499                    candle_core::DType::F8E4M3,
1500                )?;
1501                let gate_up_scale = experts_vb.get_with_hints_dtype(
1502                    (
1503                        num_experts,
1504                        hidden_size.div_ceil(weight_block_size[0]),
1505                        (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
1506                    ),
1507                    "gate_up_proj.weight_scale_inv",
1508                    Default::default(),
1509                    candle_core::DType::F32,
1510                )?;
1511
1512                // Load down_proj FP8 tensor and scale
1513                // Shape: [num_experts, intermediate_size, hidden_size]
1514                let down_fp8 = experts_vb.get_with_hints_dtype(
1515                    (num_experts, moe_intermediate_size, hidden_size),
1516                    "down_proj",
1517                    Default::default(),
1518                    candle_core::DType::F8E4M3,
1519                )?;
1520                let down_scale = experts_vb.get_with_hints_dtype(
1521                    (
1522                        num_experts,
1523                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1524                        hidden_size.div_ceil(weight_block_size[1]),
1525                    ),
1526                    "down_proj.weight_scale_inv",
1527                    Default::default(),
1528                    candle_core::DType::F32,
1529                )?;
1530
1531                // Split gate_up into gate and up
1532                let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
1533                let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1534
1535                // Split scales similarly
1536                let gate_scale = gate_up_scale.narrow(
1537                    2,
1538                    0,
1539                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1540                )?;
1541                let up_scale = gate_up_scale.narrow(
1542                    2,
1543                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1544                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1545                )?;
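                // Note: this split assumes moe_intermediate_size is a multiple of
                // weight_block_size[1]. Continuing the hypothetical example above, the
                // [128, 16, 12] scale tensor is split at column 768 / 128 = 6 into two
                // [128, 16, 6] halves.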
1546
1547                // Transpose to match expected format: [num_experts, N, K]
1548                // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1549                let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
1550                let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
1551                // down: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1552                let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;
1553
1554                // Transpose scales to match weight layout
1555                let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
1556                let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
1557                let down_scale = down_scale.transpose(1, 2)?.contiguous()?;
1558
1559                // Create BlockwiseFP8Linear for each projection
1560                let fused_gate_proj =
1561                    blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1562                let fused_up_proj =
1563                    blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1564                let fused_down_proj =
1565                    blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1566
1567                (fused_gate_proj, fused_up_proj, fused_down_proj)
1568            } else {
1569                // FP8 config but no scale tensors - weights are actually unquantized
1570                tracing::warn!(
1571                    "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
1572                    Loading as unquantized."
1573                );
1574                let gate_up_proj = experts_vb.get(
1575                    (num_experts, hidden_size, moe_intermediate_size * 2),
1576                    "gate_up_proj",
1577                )?;
1578                let down_proj_packed = experts_vb.get(
1579                    (num_experts, moe_intermediate_size, hidden_size),
1580                    "down_proj",
1581                )?;
1582
1583                // Split gate_up_proj into gate_proj and up_proj along the last dimension
1584                let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1585                let up_proj =
1586                    gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1587
1588                // Transpose dims 1 and 2 to match GGUF format
1589                let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1590                let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1591                let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1592
1593                let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1594                    QuantMethodConfig::Unquantized(Linear::new(gate_proj.clone(), None)),
1595                )?);
1596                let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1597                    QuantMethodConfig::Unquantized(Linear::new(up_proj.clone(), None)),
1598                )?);
1599                let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1600                    QuantMethodConfig::Unquantized(Linear::new(down_proj.clone(), None)),
1601                )?);
1602                // Use apply_immediate_isq_always to ensure ISQ is applied to expert weights
1603                let device = gate_proj.device();
1604                fused_gate_proj = apply_immediate_isq_always(fused_gate_proj, device)?;
1605                fused_up_proj = apply_immediate_isq_always(fused_up_proj, device)?;
1606                fused_down_proj = apply_immediate_isq_always(fused_down_proj, device)?;
1607
1608                (fused_gate_proj, fused_up_proj, fused_down_proj)
1609            }
1610        } else if is_stacked_format
1611            && matches!(&quantization_config, Some(QuantizedConfig::MXFP4 {}))
1612        {
1613            // Stacked format with MXFP4 quantization
1614            // For MXFP4, weights are stored as packed FP4 (2 values per byte)
1615            // with E8M0 scales
1616            let quantization_config = quantization_config.as_ref().unwrap();
1617
1618            // Load MXFP4 packed experts using MXFP4Layer::packed_linear_b
1619            // The tensors are expected at:
1620            //   gate_proj.blocks: [num_experts, intermediate_size, hidden_size/2]
1621            //   gate_proj.scales: [num_experts, intermediate_size, hidden_size/32]
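            // Rough arithmetic behind those shapes: MXFP4 packs two 4-bit values per byte
            // (hence hidden_size / 2 block bytes per row) and stores one E8M0 scale per
            // 32-element group (hence hidden_size / 32 scales per row). With a hypothetical
            // hidden_size of 2048, that is 1024 bytes and 64 scales per output row.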
1622            let fused_gate_proj = MXFP4Layer::packed_linear_b(
1623                num_experts,
1624                hidden_size,
1625                moe_intermediate_size,
1626                quantization_config,
1627                false,
1628                experts_vb.pp("gate_proj"),
1629            )?;
1630            let fused_up_proj = MXFP4Layer::packed_linear_b(
1631                num_experts,
1632                hidden_size,
1633                moe_intermediate_size,
1634                quantization_config,
1635                false,
1636                experts_vb.pp("up_proj"),
1637            )?;
1638            let fused_down_proj = MXFP4Layer::packed_linear_b(
1639                num_experts,
1640                moe_intermediate_size,
1641                hidden_size,
1642                quantization_config,
1643                false,
1644                experts_vb.pp("down_proj"),
1645            )?;
1646
1647            (fused_gate_proj, fused_up_proj, fused_down_proj)
1648        } else if is_stacked_format {
1649            // Stacked format from safetensors:
1650            // - gate_up_proj: [num_experts, hidden_size, intermediate_size * 2] = [128, 2048, 1536]
1651            // - down_proj: [num_experts, intermediate_size, hidden_size] = [128, 768, 2048]
1652            //
1653            // GGUF/indexed_moe_forward expects:
1654            // - gate/up: [num_experts, intermediate_size, hidden_size] = [128, 768, 2048]
1655            // - down: [num_experts, hidden_size, intermediate_size] = [128, 2048, 768]
1656            let gate_up_proj = experts_vb.get(
1657                (num_experts, hidden_size, moe_intermediate_size * 2),
1658                "gate_up_proj",
1659            )?;
1660            let down_proj_packed = experts_vb.get(
1661                (num_experts, moe_intermediate_size, hidden_size),
1662                "down_proj",
1663            )?;
1664
1665            // Split gate_up_proj into gate_proj and up_proj along the last dimension
1666            // gate_proj: [num_experts, hidden_size, intermediate_size]
1667            // up_proj: [num_experts, hidden_size, intermediate_size]
1668            let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1669            let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1670
1671            // Transpose dims 1 and 2 to match GGUF format:
1672            // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1673            let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1674            let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1675            // down_proj: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1676            let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1677
1678            let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1679                QuantMethodConfig::Unquantized(Linear::new(gate_proj.clone(), None)),
1680            )?);
1681            let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1682                QuantMethodConfig::Unquantized(Linear::new(up_proj.clone(), None)),
1683            )?);
1684            let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1685                QuantMethodConfig::Unquantized(Linear::new(down_proj.clone(), None)),
1686            )?);
1687            // Use apply_immediate_isq_always to ensure ISQ is applied to expert weights
1688            let device = gate_proj.device();
1689            fused_gate_proj = apply_immediate_isq_always(fused_gate_proj, device)?;
1690            fused_up_proj = apply_immediate_isq_always(fused_up_proj, device)?;
1691            fused_down_proj = apply_immediate_isq_always(fused_down_proj, device)?;
1692
1693            (fused_gate_proj, fused_up_proj, fused_down_proj)
1694        } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1695            // Per-expert format with FP8 quantization
1696            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1697            let weight_block_size = match quantization_config {
1698                Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1699                _ => unreachable!(),
1700            };
1701
1702            let Some(weight_block_size) = weight_block_size else {
1703                candle_core::bail!(
1704                    "Blockwise FP8 for per-expert format requires weight_block_size to be set."
1705                )
1706            };
1707            if weight_block_size.len() != 2 {
1708                candle_core::bail!(
1709                    "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1710                );
1711            }
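            // Per-expert example (hypothetical sizes): with weight_block_size = [128, 128],
            // hidden_size = 2048, and moe_intermediate_size = 768, each expert's
            // gate_proj.weight is [768, 2048] in F8E4M3 and gate_proj.weight_scale_inv is
            // [768/128, 2048/128] = [6, 16]; down_proj has the two dims swapped.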
1712
1713            let mut gate_fp8_vec = Vec::new();
1714            let mut gate_scale_vec = Vec::new();
1715            let mut up_fp8_vec = Vec::new();
1716            let mut up_scale_vec = Vec::new();
1717            let mut down_fp8_vec = Vec::new();
1718            let mut down_scale_vec = Vec::new();
1719
1720            for i in 0..num_experts {
1721                let expert_vb = experts_vb.pp(i);
1722
1723                // Load FP8 weights and scales for each projection
1724                let gate_fp8 = expert_vb.get_with_hints_dtype(
1725                    (moe_intermediate_size, hidden_size),
1726                    "gate_proj.weight",
1727                    Default::default(),
1728                    candle_core::DType::F8E4M3,
1729                )?;
1730                let gate_scale = expert_vb.get_with_hints_dtype(
1731                    (
1732                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1733                        hidden_size.div_ceil(weight_block_size[1]),
1734                    ),
1735                    "gate_proj.weight_scale_inv",
1736                    Default::default(),
1737                    candle_core::DType::F32,
1738                )?;
1739
1740                let up_fp8 = expert_vb.get_with_hints_dtype(
1741                    (moe_intermediate_size, hidden_size),
1742                    "up_proj.weight",
1743                    Default::default(),
1744                    candle_core::DType::F8E4M3,
1745                )?;
1746                let up_scale = expert_vb.get_with_hints_dtype(
1747                    (
1748                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1749                        hidden_size.div_ceil(weight_block_size[1]),
1750                    ),
1751                    "up_proj.weight_scale_inv",
1752                    Default::default(),
1753                    candle_core::DType::F32,
1754                )?;
1755
1756                let down_fp8 = expert_vb.get_with_hints_dtype(
1757                    (hidden_size, moe_intermediate_size),
1758                    "down_proj.weight",
1759                    Default::default(),
1760                    candle_core::DType::F8E4M3,
1761                )?;
1762                let down_scale = expert_vb.get_with_hints_dtype(
1763                    (
1764                        hidden_size.div_ceil(weight_block_size[0]),
1765                        moe_intermediate_size.div_ceil(weight_block_size[1]),
1766                    ),
1767                    "down_proj.weight_scale_inv",
1768                    Default::default(),
1769                    candle_core::DType::F32,
1770                )?;
1771
1772                gate_fp8_vec.push(gate_fp8);
1773                gate_scale_vec.push(gate_scale);
1774                up_fp8_vec.push(up_fp8);
1775                up_scale_vec.push(up_scale);
1776                down_fp8_vec.push(down_fp8);
1777                down_scale_vec.push(down_scale);
1778            }
1779
1780            // Stack into [num_experts, N, K]
1781            let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
1782            let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
1783            let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
1784            let up_scale = Tensor::stack(&up_scale_vec, 0)?;
1785            let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
1786            let down_scale = Tensor::stack(&down_scale_vec, 0)?;
1787
1788            // Create BlockwiseFP8Linear for each projection
1789            let fused_gate_proj =
1790                blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1791            let fused_up_proj =
1792                blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1793            let fused_down_proj =
1794                blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1795
1796            (fused_gate_proj, fused_up_proj, fused_down_proj)
1797        } else {
1798            // Per-expert format: load each expert individually and stack
1799            let mut gate_proj_vec = Vec::new();
1800            let mut up_proj_vec = Vec::new();
1801            let mut down_proj_vec = Vec::new();
1802            for i in 0..num_experts {
1803                let expert_vb = experts_vb.pp(i);
1804                let gate_proj =
1805                    expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1806                let up_proj =
1807                    expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1808                let down_proj =
1809                    expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1810
1811                gate_proj_vec.push(gate_proj);
1812                up_proj_vec.push(up_proj);
1813                down_proj_vec.push(down_proj);
1814            }
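            // After Tensor::stack below, gate/up are [num_experts, moe_intermediate_size,
            // hidden_size] and down is [num_experts, hidden_size, moe_intermediate_size],
            // i.e. the same [num_experts, N, K] layout produced by the stacked-format branches.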
1815
1816            let mut gate_proj: Arc<dyn QuantMethod> =
1817                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1818                    Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1819                ))?);
1820            let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1821                QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
1822            )?);
1823            let mut down_proj: Arc<dyn QuantMethod> =
1824                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1825                    Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1826                ))?);
1827            // Use experts.0.{proj} prefix to match the actual weight paths for ISQ predicate matching
1828            let expert0_vb = experts_vb.pp("0");
1829            gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
1830            up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
1831            down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;
1832
1833            (gate_proj, up_proj, down_proj)
1834        };
1835
1836        Ok(Self {
1837            fused_gate_proj,
1838            fused_up_proj,
1839            fused_down_proj,
1840        })
1841    }
1842}
1843
1844/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
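/// For illustration (hypothetical sizes): with `total_num_kv_heads = 2`, `head_dim = 64`, and a
/// world size of 8, `kv_replicate` is 4 and `num_kv_heads` is 1, so ranks 0-3 all load the slice
/// for KV head 0 (offset 0, len 64 on dim 0) while ranks 4-7 load the slice for KV head 1
/// (offset 64, len 64).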
1845pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
1846    if comm.world_size() == 1 {
1847        return Shard::default();
1848    }
1849
1850    // Tensor parallelism case
1851
1852    // We may need to replicate the kv heads
1853    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1854        comm.world_size() / total_num_kv_heads
1855    } else {
1856        return Shard::Simple {
1857            dim: 0,
1858            rank: comm.rank(),
1859            world_size: comm.world_size(),
1860        };
1861    };
1862
1863    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
1864    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
1865    Shard::Offset {
1866        dim: 0,
1867        offset: kv_shard_id * head_dim,
1868        len: head_dim,
1869    }
1870}
1871
1872/// Compute the number of KV groups, taking into account KV head replication.
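/// Continuing the hypothetical example above (`total_num_kv_heads = 2`,
/// `num_attention_heads = 16`, world size 8): `kv_replicate` is 4, giving
/// `(16 / 2) / 4 = 2` groups.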
1873pub fn compute_n_kv_groups(
1874    total_num_kv_heads: usize,
1875    num_attention_heads: usize,
1876    comm: &Comm,
1877) -> usize {
1878    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1879        comm.world_size() / total_num_kv_heads
1880    } else {
1881        1
1882    };
1883    if kv_replicate != 0 {
1884        (num_attention_heads / total_num_kv_heads) / kv_replicate
1885    } else {
1886        num_attention_heads / total_num_kv_heads
1887    }
1888}