mistralrs_quant/distributed/
layers.rs

1use std::sync::Arc;
2
3use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
4use candle_nn::Linear;
5
6use crate::{
7    blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
8    distributed,
9    gptq::gptq_linear,
10    lora::merge_lora_weights,
11    should_apply_immediate_isq,
12    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
13    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
14    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
15    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
16};
17
18use super::{Comm, SumAllReduce};
19
20fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
21    Shard::Simple {
22        dim,
23        rank,
24        world_size,
25    }
26}
27
28/// This layer has a weight that is parallelized along the input dimension,
29/// returning the "full" output dimension.
30#[derive(Debug)]
31pub struct RowParallelLayer {
32    weight: Arc<dyn QuantMethod>,
33    bias: Option<Tensor>,
34    all_reduce: distributed::SumAllReduce,
35}
36
37impl RowParallelLayer {
38    #[allow(clippy::new_ret_no_self)]
39    pub fn new(
40        in_dim: usize,
41        out_dim: usize,
42        config: &Option<QuantizedConfig>,
43        bias: bool,
44        comm: &Arc<crate::Comm>,
45        vb: ShardedVarBuilder,
46    ) -> Result<Arc<dyn QuantMethod>> {
47        let rank = comm.rank();
48        let world_size = comm.world_size();
49        let shard = shard(1, rank, world_size);
50
51        let base_vb = vb.clone();
52        let vb = if should_apply_immediate_isq(&vb) {
53            vb.set_device(Device::Cpu)
54        } else {
55            vb
56        };
57
58        let weight = if let Some(quant_conf) = &config {
59            // GPTQ and BNB do not support tensor parallelism
60            if matches!(
61                quant_conf,
62                QuantizedConfig::GptqAwq { .. }
63                    | QuantizedConfig::Bitsandbytes { .. }
64                    | QuantizedConfig::Afq { .. }
65            ) && comm.world_size() != 1
66            {
67                candle_core::bail!(
68                    "GPTQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
69                    comm.world_size()
70                );
71            }
72
73            match quant_conf {
74                QuantizedConfig::GptqAwq { .. } => {
75                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
76                }
77                QuantizedConfig::Fp8 { .. } => {
78                    // NOTE: no bias for fp8 as it might be parallelized
79                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
80                }
81                QuantizedConfig::Bitsandbytes { .. } => {
82                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
83                }
84                QuantizedConfig::Afq { .. } => {
85                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
86                }
87                QuantizedConfig::MXFP4 {} => {
88                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
89                }
90            }
91        } else {
92            // Handle the case where the layer is dummy (no tensors)
93            if !vb.contains_tensor("weight") {
94                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
95                Arc::new(layer) as Arc<dyn QuantMethod>
96            } else {
97                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
98                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
99
100                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
101                    Linear::new(weight, None),
102                ))?;
103                Arc::new(layer) as Arc<dyn QuantMethod>
104            }
105        };
106
107        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
108        let bias = if bias && vb.contains_tensor("bias") {
109            Some(vb.get((out_dim,), "bias")?)
110        } else {
111            None
112        };
113
114        let this_unquant = Arc::new(Self {
115            weight,
116            bias,
117            all_reduce: distributed::SumAllReduce::new(comm),
118        });
119        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
120        Ok(this)
121    }
122
123    #[allow(clippy::new_ret_no_self)]
124    pub fn new_matformer(
125        in_dim: usize,
126        out_dim: usize,
127        orig_intermediate_size: usize,
128        config: &Option<QuantizedConfig>,
129        bias: bool,
130        comm: &Arc<crate::Comm>,
131        vb: ShardedVarBuilder,
132    ) -> Result<Arc<dyn QuantMethod>> {
133        let rank = comm.rank();
134        let world_size = comm.world_size();
135        let shard = shard(1, rank, world_size);
136
137        let base_vb = vb.clone();
138        let vb = if should_apply_immediate_isq(&vb) {
139            vb.set_device(Device::Cpu)
140        } else {
141            vb
142        };
143
144        if config.is_some() {
145            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
146        }
147
148        // Handle the case where the layer is dummy (no tensors)
149        let weight = if !vb.contains_tensor("weight") {
150            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
151            Arc::new(layer) as Arc<dyn QuantMethod>
152        } else {
153            let weight = vb
154                .get_with_hints(
155                    (out_dim, orig_intermediate_size),
156                    "weight",
157                    Default::default(),
158                )?
159                .i((.., ..in_dim))?
160                .contiguous()?;
161
162            let weight = shard.apply_to(&weight)?;
163            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
164
165            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
166                Linear::new(weight, None),
167            ))?;
168            Arc::new(layer) as Arc<dyn QuantMethod>
169        };
170
171        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
172        let bias = if bias && vb.contains_tensor("bias") {
173            Some(vb.get((out_dim,), "bias")?)
174        } else {
175            None
176        };
177
178        let this_unquant = Arc::new(Self {
179            weight,
180            bias,
181            all_reduce: distributed::SumAllReduce::new(comm),
182        });
183        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
184        Ok(this)
185    }
186}
187
188impl QuantMethod for RowParallelLayer {
189    fn new(_method: QuantMethodConfig) -> Result<Self>
190    where
191        Self: Sized,
192    {
193        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
194    }
195
196    fn forward(&self, a: &Tensor) -> Result<Tensor> {
197        let mut xs = self.weight.forward(a)?;
198        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
199        if let Some(bias) = &self.bias {
200            xs = xs.broadcast_add(bias)?;
201        }
202        Ok(xs)
203    }
204
205    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
206        let weight = self.weight.add_delta_w(delta)?;
207        Ok(Arc::new(Self {
208            weight,
209            bias: self.bias.clone(),
210            all_reduce: self.all_reduce.clone(),
211        }))
212    }
213
214    fn dequantize_w(&self) -> Result<Tensor> {
215        self.weight.dequantize_w()
216    }
217
218    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
219        self.weight.dtype_and_device()
220    }
221
222    fn begin_track_stats(&mut self) -> Result<()> {
223        Arc::get_mut(&mut self.weight)
224            .context("Failed to get &mut to weight")?
225            .begin_track_stats()
226    }
227
228    fn end_track_stats(&self) -> Result<Tensor> {
229        self.weight.end_track_stats()
230    }
231
232    fn quantized_act_type(&self) -> Option<candle_core::DType> {
233        self.weight.quantized_act_type()
234    }
235
236    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
237        self.weight.unquant_weight_bias()
238    }
239
240    fn apply_isq(
241        self: Arc<Self>,
242        dtype: Option<crate::IsqType>,
243        device: candle_core::Device,
244        n_quantized: &std::sync::atomic::AtomicUsize,
245        imatrix_weight: Option<Vec<f32>>,
246        guard: QuantizeOntoGuard,
247    ) -> Result<Arc<dyn QuantMethod>> {
248        let weight =
249            self.weight
250                .clone()
251                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
252        let bias = match &self.bias {
253            Some(b) => {
254                let (dtype, device) = weight.dtype_and_device();
255                Some(b.to_device(&device)?.to_dtype(dtype)?)
256            }
257            None => None,
258        };
259        Ok(Arc::new(Self {
260            weight,
261            bias,
262            all_reduce: self.all_reduce.clone(),
263        }))
264    }
265
266    fn is_distributed(&self) -> Option<DistributedKind> {
267        Some(DistributedKind::RowParallel)
268    }
269}
270
271impl QuantizedSerde for RowParallelLayer {
272    fn isq_serde_supported(&self) -> bool {
273        self.weight.isq_serde_supported()
274    }
275    fn name(&self) -> &'static str {
276        self.weight.name()
277    }
278    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
279        self.weight.serialize_with_bias(self.bias.clone())
280    }
281    fn deserialize(
282        data: std::borrow::Cow<[u8]>,
283        device: &candle_core::Device,
284        comm: &Arc<crate::Comm>,
285        guard: QuantizeOntoGuard,
286    ) -> Result<Arc<dyn QuantMethod>>
287    where
288        Self: Sized,
289    {
290        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
291        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
292        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
293            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
294            QuantizedSerdeType::Unquant => {
295                UnquantLinear::deserialize_ext_bias(data, device, guard)?
296            }
297            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
298            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
299            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
300        };
301        Ok(Arc::new(Self {
302            weight,
303            bias,
304            all_reduce: SumAllReduce::new(comm),
305        }))
306    }
307}
308
309#[derive(Debug)]
310/// This layer has a weight that is parallelized along the output dimension,
311/// taking the "full" input dimension.
312pub struct ColumnParallelLayer {
313    weight: Arc<dyn QuantMethod>,
314    bias: Option<Tensor>,
315}
316
317impl ColumnParallelLayer {
318    #[allow(clippy::new_ret_no_self)]
319    pub fn new_with_shard(
320        in_dim: usize,
321        out_dim: usize,
322        config: &Option<QuantizedConfig>,
323        bias: bool,
324        comm: &Arc<crate::Comm>,
325        shard: Shard,
326        vb: ShardedVarBuilder,
327    ) -> Result<Arc<dyn QuantMethod>> {
328        let base_vb = vb.clone();
329        let vb = if should_apply_immediate_isq(&vb) {
330            vb.set_device(Device::Cpu)
331        } else {
332            vb
333        };
334
335        let weight = if let Some(quant_conf) = &config {
336            // GPTQ and BNB do not support tensor parallelism
337            if matches!(
338                quant_conf,
339                QuantizedConfig::GptqAwq { .. }
340                    | QuantizedConfig::Bitsandbytes { .. }
341                    | QuantizedConfig::Afq { .. }
342            ) && comm.world_size() != 1
343            {
344                candle_core::bail!(
345                    "GPTQ/AWQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
346                    comm.world_size()
347                );
348            }
349
350            match quant_conf {
351                QuantizedConfig::GptqAwq { .. } => {
352                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
353                }
354                QuantizedConfig::Fp8 { .. } => {
355                    // NOTE: no bias for fp8 as it might be parallelized
356                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
357                }
358                QuantizedConfig::Bitsandbytes { .. } => {
359                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
360                }
361                QuantizedConfig::Afq { .. } => {
362                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
363                }
364                QuantizedConfig::MXFP4 {} => {
365                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
366                }
367            }
368        } else {
369            // Handle the case where the layer is dummy (no tensors)
370            if !vb.contains_tensor("weight") {
371                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
372                Arc::new(layer) as Arc<dyn QuantMethod>
373            } else {
374                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
375                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
376
377                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
378                    Linear::new(weight, None),
379                ))?;
380                Arc::new(layer) as Arc<dyn QuantMethod>
381            }
382        };
383
384        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
385        let bias = if bias && vb.contains_tensor("bias") {
386            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
387        } else {
388            None
389        };
390
391        let this_unquant = Arc::new(Self { weight, bias });
392        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
393        Ok(this)
394    }
395
396    #[allow(clippy::new_ret_no_self)]
397    pub fn new(
398        in_dim: usize,
399        out_dim: usize,
400        config: &Option<QuantizedConfig>,
401        bias: bool,
402        comm: &Arc<crate::Comm>,
403        vb: ShardedVarBuilder,
404    ) -> Result<Arc<dyn QuantMethod>> {
405        let rank = comm.rank();
406        let world_size = comm.world_size();
407        let shard = shard(0, rank, world_size);
408
409        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
410    }
411
412    #[allow(clippy::new_ret_no_self)]
413    pub fn new_matformer(
414        in_dim: usize,
415        out_dim: usize,
416        orig_intermediate_size: usize,
417        config: &Option<QuantizedConfig>,
418        bias: bool,
419        comm: &Arc<crate::Comm>,
420        vb: ShardedVarBuilder,
421    ) -> Result<Arc<dyn QuantMethod>> {
422        let rank = comm.rank();
423        let world_size = comm.world_size();
424        let shard = shard(0, rank, world_size);
425
426        let base_vb = vb.clone();
427        let vb = if should_apply_immediate_isq(&vb) {
428            vb.set_device(Device::Cpu)
429        } else {
430            vb
431        };
432
433        if config.is_some() {
434            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
435        }
436
437        // Handle the case where the layer is dummy (no tensors)
438        let weight = if !vb.contains_tensor("weight") {
439            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
440            Arc::new(layer) as Arc<dyn QuantMethod>
441        } else {
442            let weight = vb
443                .get_with_hints(
444                    (orig_intermediate_size, in_dim),
445                    "weight",
446                    Default::default(),
447                )?
448                .i((..out_dim, ..))?
449                .contiguous()?;
450
451            let weight = shard.apply_to(&weight)?;
452            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
453
454            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
455                Linear::new(weight, None),
456            ))?;
457            Arc::new(layer) as Arc<dyn QuantMethod>
458        };
459
460        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
461        let bias = if bias && vb.contains_tensor("bias") {
462            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
463        } else {
464            None
465        };
466
467        let this_unquant = Arc::new(Self { weight, bias });
468        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
469        Ok(this)
470    }
471
472    pub fn new_merged(
473        in_dim: usize,
474        out_dim: usize,
475        chunks: usize,
476        config: &Option<QuantizedConfig>,
477        bias: bool,
478        comm: &Arc<crate::Comm>,
479        vb: ShardedVarBuilder,
480    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
481        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
482        for chunk_idx in 0..chunks {
483            let layer = ColumnParallelLayer::new_with_shard(
484                in_dim,
485                out_dim,
486                config,
487                bias,
488                comm,
489                shard(
490                    0,
491                    chunk_idx * comm.world_size() + comm.rank(),
492                    chunks * comm.world_size(),
493                ),
494                vb.clone(),
495            )?;
496            vec_layers.push(layer);
497        }
498        Ok(vec_layers)
499    }
500}
501
502impl QuantMethod for ColumnParallelLayer {
503    fn new(_method: QuantMethodConfig) -> Result<Self>
504    where
505        Self: Sized,
506    {
507        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
508    }
509
510    fn forward(&self, a: &Tensor) -> Result<Tensor> {
511        let mut xs = self.weight.forward(a)?;
512        if let Some(bias) = &self.bias {
513            xs = xs.broadcast_add(bias)?;
514        }
515        Ok(xs)
516    }
517
518    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
519        let weight = self.weight.add_delta_w(delta)?;
520        Ok(Arc::new(Self {
521            weight,
522            bias: self.bias.clone(),
523        }))
524    }
525
526    fn dequantize_w(&self) -> Result<Tensor> {
527        self.weight.dequantize_w()
528    }
529
530    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
531        self.weight.dtype_and_device()
532    }
533
534    fn begin_track_stats(&mut self) -> Result<()> {
535        Arc::get_mut(&mut self.weight)
536            .context("Failed to get &mut to weight")?
537            .begin_track_stats()
538    }
539
540    fn end_track_stats(&self) -> Result<Tensor> {
541        self.weight.end_track_stats()
542    }
543
544    fn quantized_act_type(&self) -> Option<candle_core::DType> {
545        self.weight.quantized_act_type()
546    }
547
548    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
549        self.weight.unquant_weight_bias()
550    }
551
552    fn apply_isq(
553        self: Arc<Self>,
554        dtype: Option<crate::IsqType>,
555        device: candle_core::Device,
556        n_quantized: &std::sync::atomic::AtomicUsize,
557        imatrix_weight: Option<Vec<f32>>,
558        guard: QuantizeOntoGuard,
559    ) -> Result<Arc<dyn QuantMethod>> {
560        let weight =
561            self.weight
562                .clone()
563                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
564        let bias = match &self.bias {
565            Some(b) => {
566                let (dtype, device) = weight.dtype_and_device();
567                Some(b.to_device(&device)?.to_dtype(dtype)?)
568            }
569            None => None,
570        };
571        Ok(Arc::new(Self { weight, bias }))
572    }
573
574    fn is_distributed(&self) -> Option<DistributedKind> {
575        Some(DistributedKind::ColumnParallel)
576    }
577}
578
579impl QuantizedSerde for ColumnParallelLayer {
580    fn isq_serde_supported(&self) -> bool {
581        self.weight.isq_serde_supported()
582    }
583    fn name(&self) -> &'static str {
584        self.weight.name()
585    }
586    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
587        self.weight.serialize_with_bias(self.bias.clone())
588    }
589    fn deserialize(
590        data: std::borrow::Cow<[u8]>,
591        device: &candle_core::Device,
592        _comm: &Arc<crate::Comm>,
593        guard: QuantizeOntoGuard,
594    ) -> Result<Arc<dyn QuantMethod>>
595    where
596        Self: Sized,
597    {
598        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
599        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
600        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
601            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
602            QuantizedSerdeType::Unquant => {
603                UnquantLinear::deserialize_ext_bias(data, device, guard)?
604            }
605            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
606            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
607            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
608        };
609        Ok(Arc::new(Self { weight, bias }))
610    }
611}
612
613#[derive(Debug)]
614/// This layer has no parallelization
615pub struct ReplicatedLayer(Arc<dyn QuantMethod>);
616
617impl ReplicatedLayer {
618    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
619        let dev = lin.weight().device().clone();
620        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
621        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
622        Ok(this)
623    }
624
625    #[allow(clippy::new_ret_no_self)]
626    pub fn new(
627        in_dim: usize,
628        out_dim: usize,
629        config: &Option<QuantizedConfig>,
630        bias: bool,
631        vb: ShardedVarBuilder,
632    ) -> Result<Arc<dyn QuantMethod>> {
633        let base_vb = vb.clone();
634        let vb = if should_apply_immediate_isq(&vb) {
635            vb.set_device(Device::Cpu)
636        } else {
637            vb
638        };
639
640        let layer = if let Some(quant_conf) = &config {
641            match quant_conf {
642                QuantizedConfig::GptqAwq { .. } => {
643                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
644                }
645                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
646                    in_dim,
647                    out_dim,
648                    quant_conf,
649                    bias,
650                    Default::default(),
651                    vb.clone(),
652                )?,
653                QuantizedConfig::Bitsandbytes { .. } => {
654                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
655                }
656                QuantizedConfig::Afq { .. } => {
657                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
658                }
659                QuantizedConfig::MXFP4 {} => {
660                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
661                }
662            }
663        } else {
664            // Handle the case where the layer is dummy (no tensors)
665            if !vb.contains_tensor("weight") {
666                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
667                Arc::new(layer) as Arc<dyn QuantMethod>
668            } else {
669                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
670                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
671
672                let bias = if bias {
673                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
674                } else {
675                    None
676                };
677                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
678                    Linear::new(weight, bias),
679                ))?;
680                Arc::new(layer) as Arc<dyn QuantMethod>
681            }
682        };
683
684        let this_unquant = Arc::new(Self(layer));
685        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
686        Ok(this)
687    }
688
689    #[allow(clippy::new_ret_no_self)]
690    pub fn new_layers_matformer_indices(
691        in_dim: usize,
692        out_dim: usize,
693        kept_layers_indices: Option<&Tensor>,
694        orig_num_hidden_layers: usize,
695        config: &Option<QuantizedConfig>,
696        bias: bool,
697        vb: ShardedVarBuilder,
698    ) -> Result<Arc<dyn QuantMethod>> {
699        let base_vb = vb.clone();
700        let vb = if should_apply_immediate_isq(&vb) {
701            vb.set_device(Device::Cpu)
702        } else {
703            vb
704        };
705
706        let layer = if let Some(quant_conf) = &config {
707            if kept_layers_indices.is_some() {
708                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
709            }
710
711            match quant_conf {
712                QuantizedConfig::GptqAwq { .. } => {
713                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
714                }
715                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
716                    in_dim,
717                    out_dim,
718                    quant_conf,
719                    bias,
720                    Default::default(),
721                    vb.clone(),
722                )?,
723                QuantizedConfig::Bitsandbytes { .. } => {
724                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
725                }
726                QuantizedConfig::Afq { .. } => {
727                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
728                }
729                QuantizedConfig::MXFP4 {} => {
730                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
731                }
732            }
733        } else {
734            // Handle the case where the layer is dummy (no tensors)
735            if !vb.contains_tensor("weight") {
736                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
737                Arc::new(layer) as Arc<dyn QuantMethod>
738            } else {
739                let mut weight =
740                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
741
742                if let Some(kept_layers_indices) = &kept_layers_indices {
743                    let weight_reshaped = weight.reshape((
744                        orig_num_hidden_layers,
745                        weight.dim(0)? / orig_num_hidden_layers,
746                        weight.dim(1)?,
747                    ))?;
748
749                    weight = weight_reshaped
750                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
751                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
752                        .contiguous()?;
753                }
754
755                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
756
757                let bias = if bias {
758                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
759                } else {
760                    None
761                };
762                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
763                    Linear::new(weight, bias),
764                ))?;
765                Arc::new(layer) as Arc<dyn QuantMethod>
766            }
767        };
768
769        let this_unquant = Arc::new(Self(layer));
770        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
771        Ok(this)
772    }
773}
774
775impl QuantMethod for ReplicatedLayer {
776    fn new(_method: QuantMethodConfig) -> Result<Self>
777    where
778        Self: Sized,
779    {
780        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
781    }
782
783    fn forward(&self, a: &Tensor) -> Result<Tensor> {
784        self.0.forward(a)
785    }
786
787    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
788        self.0.add_delta_w(delta)
789    }
790
791    fn dequantize_w(&self) -> Result<Tensor> {
792        self.0.dequantize_w()
793    }
794
795    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
796        self.0.dtype_and_device()
797    }
798
799    fn begin_track_stats(&mut self) -> Result<()> {
800        Arc::get_mut(&mut self.0)
801            .context("Failed to get &mut to weight")?
802            .begin_track_stats()
803    }
804
805    fn end_track_stats(&self) -> Result<Tensor> {
806        self.0.end_track_stats()
807    }
808
809    fn quantized_act_type(&self) -> Option<candle_core::DType> {
810        self.0.quantized_act_type()
811    }
812
813    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
814        self.0.unquant_weight_bias()
815    }
816
817    fn apply_isq(
818        self: Arc<Self>,
819        dtype: Option<crate::IsqType>,
820        device: candle_core::Device,
821        n_quantized: &std::sync::atomic::AtomicUsize,
822        imatrix_weight: Option<Vec<f32>>,
823        guard: QuantizeOntoGuard,
824    ) -> Result<Arc<dyn QuantMethod>> {
825        self.0
826            .clone()
827            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
828    }
829
830    fn is_distributed(&self) -> Option<DistributedKind> {
831        Some(DistributedKind::Replicated)
832    }
833}
834
835impl QuantizedSerde for ReplicatedLayer {
836    fn isq_serde_supported(&self) -> bool {
837        self.0.isq_serde_supported()
838    }
839    fn name(&self) -> &'static str {
840        self.0.name()
841    }
842    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
843        self.0.serialize()
844    }
845    fn deserialize(
846        data: std::borrow::Cow<[u8]>,
847        device: &candle_core::Device,
848        comm: &Arc<crate::Comm>,
849        guard: QuantizeOntoGuard,
850    ) -> Result<Arc<dyn QuantMethod>>
851    where
852        Self: Sized,
853    {
854        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
855        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
856        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
857            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
858            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
859            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
860            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
861            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
862        };
863        Ok(Arc::new(Self(deserialized)))
864    }
865}
866
867#[derive(Debug)]
868pub struct PackedExperts {
869    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
870    pub up_proj: Vec<Arc<dyn QuantMethod>>,
871    pub down_proj: Vec<Arc<dyn QuantMethod>>,
872}
873
874impl PackedExperts {
875    /// Note: we only support AFQ and unquantized here because they are the only ones that support indexed.
876    #[allow(clippy::too_many_arguments)]
877    pub fn new(
878        num_local_experts: usize,
879        hidden_size: usize,
880        intermediate_size: usize,
881        config: &Option<QuantizedConfig>,
882        bias: bool,
883        comm: &Arc<crate::Comm>,
884        vb: ShardedVarBuilder,
885    ) -> Result<Self> {
886        if bias {
887            candle_core::bail!("PackedExperts does not support bias.");
888        }
889
890        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
891            // GPTQ and BNB do not support tensor parallelism
892            if comm.world_size() != 1 {
893                candle_core::bail!(
894                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
895                    comm.world_size()
896                );
897            }
898
899            match quant_conf {
900                QuantizedConfig::Afq { .. } => {
901                    if !vb.contains_tensor("gate_up_proj")
902                        || !vb.contains_tensor("gate_up_proj.weight")
903                    {
904                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
905                    }
906
907                    let base_vb = vb.clone();
908
909                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
910                        vb.pp("gate_proj").set_device(Device::Cpu)
911                    } else {
912                        vb.pp("gate_proj")
913                    };
914                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
915                        vb.pp("up_proj").set_device(Device::Cpu)
916                    } else {
917                        vb.pp("up_proj")
918                    };
919                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
920                        vb.pp("down_proj").set_device(Device::Cpu)
921                    } else {
922                        vb.pp("down_proj")
923                    };
924                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
925                        num_local_experts,
926                        hidden_size,
927                        intermediate_size,
928                        quant_conf,
929                        bias,
930                        vb_gate_proj,
931                    )?;
932                    let mut up_proj = AfqLayer::afq_packed_linear_b(
933                        num_local_experts,
934                        hidden_size,
935                        intermediate_size,
936                        quant_conf,
937                        bias,
938                        vb_up_proj,
939                    )?;
940                    let mut down_proj = AfqLayer::afq_packed_linear_b(
941                        num_local_experts,
942                        intermediate_size,
943                        hidden_size,
944                        quant_conf,
945                        bias,
946                        vb_down_proj,
947                    )?;
948
949                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
950                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
951                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;
952
953                    (vec![gate_proj], vec![up_proj], vec![down_proj])
954                }
955                QuantizedConfig::Fp8 { weight_block_size } => {
956                    // FP8 quantization for PackedExperts
957                    // Keep weights as FP8 using BlockwiseFP8Linear to leverage native FP8 GEMM
958                    if weight_block_size.len() != 2 {
959                        candle_core::bail!(
960                            "Expected weight_block_size to have length 2, got {weight_block_size:?}"
961                        );
962                    }
963
964                    // Check if we have stacked format (gate_up_proj) or per-expert format
965                    // Note: vb already has the "experts" prefix from the caller (experts.rs)
966                    let is_stacked_format = vb.contains_tensor("gate_up_proj");
967
968                    if is_stacked_format {
969                        // Stacked format: load FP8 tensors and split
970                        let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");
971
972                        if has_fp8_scales {
973                            // Load gate_up_proj FP8 tensor and scale
974                            let gate_up_fp8 = vb.get_with_hints_dtype(
975                                (num_local_experts, hidden_size, intermediate_size * 2),
976                                "gate_up_proj",
977                                Default::default(),
978                                candle_core::DType::F8E4M3,
979                            )?;
980                            let gate_up_scale = vb.get_with_hints_dtype(
981                                (
982                                    num_local_experts,
983                                    hidden_size.div_ceil(weight_block_size[0]),
984                                    (intermediate_size * 2).div_ceil(weight_block_size[1]),
985                                ),
986                                "gate_up_proj.weight_scale_inv",
987                                Default::default(),
988                                candle_core::DType::F32,
989                            )?;
990
991                            // Load down_proj FP8 tensor and scale
992                            let down_fp8 = vb.get_with_hints_dtype(
993                                (num_local_experts, intermediate_size, hidden_size),
994                                "down_proj",
995                                Default::default(),
996                                candle_core::DType::F8E4M3,
997                            )?;
998                            let down_scale = vb.get_with_hints_dtype(
999                                (
1000                                    num_local_experts,
1001                                    intermediate_size.div_ceil(weight_block_size[0]),
1002                                    hidden_size.div_ceil(weight_block_size[1]),
1003                                ),
1004                                "down_proj.weight_scale_inv",
1005                                Default::default(),
1006                                candle_core::DType::F32,
1007                            )?;
1008
1009                            // Split and create individual BlockwiseFP8Linear for each expert
1010                            let mut gs = Vec::new();
1011                            let mut us = Vec::new();
1012                            let mut ds = Vec::new();
1013
1014                            for i in 0..num_local_experts {
1015                                // Extract this expert's weights
1016                                let gate_up_expert =
1017                                    gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1018                                let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
1019                                let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1020                                let down_scale_expert = down_scale.i(i)?.contiguous()?;
1021
1022                                // Split gate_up into gate and up
1023                                let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
1024                                let up_expert = gate_up_expert.narrow(
1025                                    0,
1026                                    intermediate_size,
1027                                    intermediate_size,
1028                                )?;
1029
1030                                // Split scales
1031                                let gate_scale_expert = gate_up_scale_expert.narrow(
1032                                    1,
1033                                    0,
1034                                    intermediate_size.div_ceil(weight_block_size[1]),
1035                                )?;
1036                                let up_scale_expert = gate_up_scale_expert.narrow(
1037                                    1,
1038                                    intermediate_size.div_ceil(weight_block_size[1]),
1039                                    intermediate_size.div_ceil(weight_block_size[1]),
1040                                )?;
1041
1042                                // Create BlockwiseFP8Linear for each projection
1043                                use crate::blockwise_fp8::BlockwiseFP8Linear;
1044                                use crate::QuantMethodConfig;
1045
1046                                let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1047                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1048                                        weight: gate_expert,
1049                                        weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
1050                                        bias: None,
1051                                        dequant_dtype: vb.dtype(),
1052                                        weight_block_size: weight_block_size.clone(),
1053                                    })?,
1054                                );
1055                                let up_layer: Arc<dyn QuantMethod> = Arc::new(
1056                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1057                                        weight: up_expert,
1058                                        weight_scale_inv: up_scale_expert.transpose(0, 1)?,
1059                                        bias: None,
1060                                        dequant_dtype: vb.dtype(),
1061                                        weight_block_size: weight_block_size.clone(),
1062                                    })?,
1063                                );
1064                                let down_layer: Arc<dyn QuantMethod> = Arc::new(
1065                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1066                                        weight: down_expert,
1067                                        weight_scale_inv: down_scale_expert.transpose(0, 1)?,
1068                                        bias: None,
1069                                        dequant_dtype: vb.dtype(),
1070                                        weight_block_size: weight_block_size.clone(),
1071                                    })?,
1072                                );
1073
1074                                gs.push(gate_layer);
1075                                us.push(up_layer);
1076                                ds.push(down_layer);
1077                            }
1078
1079                            (gs, us, ds)
1080                        } else {
1081                            candle_core::bail!(
1082                                "PackedExperts with FP8 requires weight_scale_inv tensors"
1083                            );
1084                        }
1085                    } else {
1086                        // Per-expert format: load each expert individually
1087                        let mut gs = Vec::new();
1088                        let mut us = Vec::new();
1089                        let mut ds = Vec::new();
1090
1091                        for i in 0..num_local_experts {
1092                            let expert_vb = vb.pp(i);
1093
1094                            // Load FP8 weights and scales for each projection
1095                            let gate_fp8 = expert_vb.get_with_hints_dtype(
1096                                (intermediate_size, hidden_size),
1097                                "gate_proj.weight",
1098                                Default::default(),
1099                                candle_core::DType::F8E4M3,
1100                            )?;
1101                            let gate_scale = expert_vb.get_with_hints_dtype(
1102                                (
1103                                    intermediate_size.div_ceil(weight_block_size[0]),
1104                                    hidden_size.div_ceil(weight_block_size[1]),
1105                                ),
1106                                "gate_proj.weight_scale_inv",
1107                                Default::default(),
1108                                candle_core::DType::F32,
1109                            )?;
1110
1111                            let up_fp8 = expert_vb.get_with_hints_dtype(
1112                                (intermediate_size, hidden_size),
1113                                "up_proj.weight",
1114                                Default::default(),
1115                                candle_core::DType::F8E4M3,
1116                            )?;
1117                            let up_scale = expert_vb.get_with_hints_dtype(
1118                                (
1119                                    intermediate_size.div_ceil(weight_block_size[0]),
1120                                    hidden_size.div_ceil(weight_block_size[1]),
1121                                ),
1122                                "up_proj.weight_scale_inv",
1123                                Default::default(),
1124                                candle_core::DType::F32,
1125                            )?;
1126
1127                            let down_fp8 = expert_vb.get_with_hints_dtype(
1128                                (hidden_size, intermediate_size),
1129                                "down_proj.weight",
1130                                Default::default(),
1131                                candle_core::DType::F8E4M3,
1132                            )?;
1133                            let down_scale = expert_vb.get_with_hints_dtype(
1134                                (
1135                                    hidden_size.div_ceil(weight_block_size[0]),
1136                                    intermediate_size.div_ceil(weight_block_size[1]),
1137                                ),
1138                                "down_proj.weight_scale_inv",
1139                                Default::default(),
1140                                candle_core::DType::F32,
1141                            )?;
1142
1143                            // Create BlockwiseFP8Linear for each projection
1144                            use crate::blockwise_fp8::BlockwiseFP8Linear;
1145                            use crate::QuantMethodConfig;
1146
1147                            let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1148                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1149                                    weight: gate_fp8,
1150                                    weight_scale_inv: gate_scale,
1151                                    bias: None,
1152                                    dequant_dtype: vb.dtype(),
1153                                    weight_block_size: weight_block_size.clone(),
1154                                })?,
1155                            );
1156                            let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
1157                                QuantMethodConfig::BlockwiseFP8 {
1158                                    weight: up_fp8,
1159                                    weight_scale_inv: up_scale,
1160                                    bias: None,
1161                                    dequant_dtype: vb.dtype(),
1162                                    weight_block_size: weight_block_size.clone(),
1163                                },
1164                            )?);
1165                            let down_layer: Arc<dyn QuantMethod> = Arc::new(
1166                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1167                                    weight: down_fp8,
1168                                    weight_scale_inv: down_scale,
1169                                    bias: None,
1170                                    dequant_dtype: vb.dtype(),
1171                                    weight_block_size: weight_block_size.clone(),
1172                                })?,
1173                            );
1174
1175                            gs.push(gate_layer);
1176                            us.push(up_layer);
1177                            ds.push(down_layer);
1178                        }
1179
1180                        (gs, us, ds)
1181                    }
1182                }
1183                _ => candle_core::bail!(
1184                    "PackedExperts with quantization config only allows AFQ or FP8 quantization"
1185                ),
1186            }
1187        } else if !vb.contains_tensor("gate_up_proj") {
1188            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
1189            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
1190            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
1191            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
1192            for _ in 0..num_local_experts {
1193                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1194                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1195                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1196            }
1197            (gs, us, ds)
1198        } else {
1199            // Parallelized like:
1200            // Each gpu holds all experts.
1201            // Gate/Up proj is parallelized on dim 2 (column)
1202            // Down proj is parallelized on dim 1 (row)
1203            // All reduce at the end.
1204
1205            // Handle the case where the layer is dummy (no tensors)
1206            let gate_up_block_size = intermediate_size / comm.world_size();
1207            let gate_up_start = gate_up_block_size * comm.rank();
1208
1209            // Gate is right before Up in the gate_up
1210            let shard_gate = Shard::Offset {
1211                dim: 2,
1212                offset: gate_up_start,
1213                len: gate_up_block_size,
1214            };
1215            let shard_up = Shard::Offset {
1216                dim: 2,
1217                offset: intermediate_size + gate_up_start,
1218                len: gate_up_block_size,
1219            };
1220            let shard_down = Shard::Simple {
1221                dim: 1,
1222                rank: comm.rank(),
1223                world_size: comm.world_size(),
1224            };
1225
1226            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
1227                vb.pp("gate_up_proj").set_device(Device::Cpu)
1228            } else {
1229                vb.pp("gate_up_proj")
1230            };
1231            let vb_down_proj = if should_apply_immediate_isq(&vb) {
1232                vb.pp("down_proj").set_device(Device::Cpu)
1233            } else {
1234                vb.pp("down_proj")
1235            };
1236
1237            let gate_proj = vb
1238                .get_with_hints(
1239                    (num_local_experts, hidden_size, intermediate_size * 2),
1240                    "gate_up_proj",
1241                    shard_gate,
1242                )?
1243                .t()?
1244                .contiguous()?;
1245            let up_proj = vb
1246                .get_with_hints(
1247                    (num_local_experts, hidden_size, intermediate_size * 2),
1248                    "gate_up_proj",
1249                    shard_up,
1250                )?
1251                .t()?
1252                .contiguous()?;
1253            let down_proj = vb
1254                .get_with_hints(
1255                    (num_local_experts, intermediate_size, hidden_size),
1256                    "down_proj",
1257                    shard_down,
1258                )?
1259                .t()?
1260                .contiguous()?;
1261
1262            let gc = gate_proj.chunk(num_local_experts, 0)?;
1263            let uc = up_proj.chunk(num_local_experts, 0)?;
1264            let dc = down_proj.chunk(num_local_experts, 0)?;
1265            drop((gate_proj, up_proj, down_proj));
1266
1267            let mut gs = Vec::new();
1268            let mut us = Vec::new();
1269            let mut ds = Vec::new();
1270            for ((mut gate_proj, mut up_proj), mut down_proj) in
1271                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
1272            {
1273                gate_proj = gate_proj.squeeze(0)?;
1274                up_proj = up_proj.squeeze(0)?;
1275                down_proj = down_proj.squeeze(0)?;
1276                let gate_proj = merge_lora_weights(
1277                    &vb,
1278                    gate_proj,
1279                    hidden_size,
1280                    intermediate_size * 2,
1281                    shard_gate,
1282                )?;
1283                let up_proj =
1284                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1285                let down_proj =
1286                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1287
1288                let mut gate_proj: Arc<dyn QuantMethod> =
1289                    Arc::new(<UnquantLinear as QuantMethod>::new(
1290                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1291                    )?);
1292                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1293                let mut up_proj: Arc<dyn QuantMethod> =
1294                    Arc::new(<UnquantLinear as QuantMethod>::new(
1295                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1296                    )?);
1297                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1298                let mut down_proj: Arc<dyn QuantMethod> =
1299                    Arc::new(<UnquantLinear as QuantMethod>::new(
1300                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1301                    )?);
1302                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1303                gs.push(gate_proj);
1304                us.push(up_proj);
1305                ds.push(down_proj);
1306            }
1307            (gs, us, ds)
1308        };
1309
1310        Ok(Self {
1311            gate_proj,
1312            up_proj,
1313            down_proj,
1314        })
1315    }
1316}
1317
1318pub struct FusedExperts {
1319    pub fused_gate_proj: Arc<dyn QuantMethod>,
1320    pub fused_up_proj: Arc<dyn QuantMethod>,
1321    pub fused_down_proj: Arc<dyn QuantMethod>,
1322}
1323
1324impl FusedExperts {
1325    pub fn new(
1326        hidden_size: usize,
1327        moe_intermediate_size: usize,
1328        num_experts: usize,
1329        quantization_config: &Option<QuantizedConfig>,
1330        vb: ShardedVarBuilder,
1331    ) -> Result<Self> {
1332        // Detect if weights are in stacked format (e.g., Qwen3 VL MoE):
1333        // - experts.gate_up_proj: (num_experts, hidden_size, intermediate_size * 2)
1334        // - experts.down_proj: (num_experts, intermediate_size, hidden_size)
1335        // Or per-expert format (e.g., Qwen3 MoE):
1336        // - experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight, experts.{i}.down_proj.weight
1337        let experts_vb = vb.pp("experts");
1338        let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");
1339
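        // Dispatch on quantization scheme and on-disk layout: AFQ packed, stacked FP8,
        // stacked unquantized, per-expert FP8, or per-expert unquantized.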
1340        let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
1341            &quantization_config,
1342            Some(QuantizedConfig::Afq { .. })
1343        ) {
1344            let quantization_config = quantization_config.as_ref().unwrap();
1345
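            // AFQ checkpoints store the experts pre-packed under the
            // `switch_mlp.{gate,up,down}_proj` prefixes used below, so no per-expert stacking is needed.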
1346            let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1347                num_experts,
1348                hidden_size,
1349                moe_intermediate_size,
1350                quantization_config,
1351                false,
1352                vb.pp("switch_mlp.gate_proj"),
1353            )?;
1354            let fused_up_proj = AfqLayer::afq_packed_linear_b(
1355                num_experts,
1356                hidden_size,
1357                moe_intermediate_size,
1358                quantization_config,
1359                false,
1360                vb.pp("switch_mlp.up_proj"),
1361            )?;
1362            let fused_down_proj = AfqLayer::afq_packed_linear_b(
1363                num_experts,
1364                moe_intermediate_size,
1365                hidden_size,
1366                quantization_config,
1367                false,
1368                vb.pp("switch_mlp.down_proj"),
1369            )?;
1370
1371            (fused_gate_proj, fused_up_proj, fused_down_proj)
1372        } else if is_stacked_format
1373            && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
1374        {
1375            // Stacked format with FP8 quantization
1376            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
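            // Some checkpoints declare FP8 in the config but actually store unquantized expert
            // weights; the presence of the blockwise scale tensor decides which path to take.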
1377            let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");
1378
1379            if has_fp8_scales {
1380                let weight_block_size = match quantization_config {
1381                    Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1382                    _ => unreachable!(),
1383                };
1384
1385                if weight_block_size.len() != 2 {
1386                    candle_core::bail!(
1387                        "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1388                    );
1389                }
1390
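                // For illustration, with a (hypothetical) block size of [128, 128] and
                // hidden_size = 2048, the gate_up scale tensor has shape
                // [num_experts, 2048 / 128 = 16, ceil(2 * moe_intermediate_size / 128)].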
1391                // Load gate_up_proj FP8 tensor and scale
1392                // Shape: [num_experts, hidden_size, intermediate_size * 2]
1393                let gate_up_fp8 = experts_vb.get_with_hints_dtype(
1394                    (num_experts, hidden_size, moe_intermediate_size * 2),
1395                    "gate_up_proj",
1396                    Default::default(),
1397                    candle_core::DType::F8E4M3,
1398                )?;
1399                let gate_up_scale = experts_vb.get_with_hints_dtype(
1400                    (
1401                        num_experts,
1402                        hidden_size.div_ceil(weight_block_size[0]),
1403                        (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
1404                    ),
1405                    "gate_up_proj.weight_scale_inv",
1406                    Default::default(),
1407                    candle_core::DType::F32,
1408                )?;
1409
1410                // Load down_proj FP8 tensor and scale
1411                // Shape: [num_experts, intermediate_size, hidden_size]
1412                let down_fp8 = experts_vb.get_with_hints_dtype(
1413                    (num_experts, moe_intermediate_size, hidden_size),
1414                    "down_proj",
1415                    Default::default(),
1416                    candle_core::DType::F8E4M3,
1417                )?;
1418                let down_scale = experts_vb.get_with_hints_dtype(
1419                    (
1420                        num_experts,
1421                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1422                        hidden_size.div_ceil(weight_block_size[1]),
1423                    ),
1424                    "down_proj.weight_scale_inv",
1425                    Default::default(),
1426                    candle_core::DType::F32,
1427                )?;
1428
1429                // Split gate_up into gate and up
1430                let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
1431                let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1432
1433                // Split scales similarly
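                // (This split assumes moe_intermediate_size is a multiple of weight_block_size[1];
                // otherwise a scale block would straddle the gate/up boundary.)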
1434                let gate_scale = gate_up_scale.narrow(
1435                    2,
1436                    0,
1437                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1438                )?;
1439                let up_scale = gate_up_scale.narrow(
1440                    2,
1441                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1442                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1443                )?;
1444
1445                // Transpose to match expected format: [num_experts, N, K]
1446                // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1447                let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
1448                let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
1449                // down: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1450                let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;
1451
1452                // Transpose scales to match weight layout
1453                let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
1454                let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
1455                let down_scale = down_scale.transpose(1, 2)?.contiguous()?;
1456
1457                // Create BlockwiseFP8Linear for each projection
1458                let fused_gate_proj =
1459                    blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1460                let fused_up_proj =
1461                    blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1462                let fused_down_proj =
1463                    blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1464
1465                (fused_gate_proj, fused_up_proj, fused_down_proj)
1466            } else {
1467                // FP8 config but no scale tensors - weights are actually unquantized
1468                tracing::warn!(
1469                    "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
1470                     Loading as unquantized."
1471                );
1472                let gate_up_proj = experts_vb.get(
1473                    (num_experts, hidden_size, moe_intermediate_size * 2),
1474                    "gate_up_proj",
1475                )?;
1476                let down_proj_packed = experts_vb.get(
1477                    (num_experts, moe_intermediate_size, hidden_size),
1478                    "down_proj",
1479                )?;
1480
1481                // Split gate_up_proj into gate_proj and up_proj along the last dimension
1482                let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1483                let up_proj =
1484                    gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1485
1486                // Transpose dims 1 and 2 to match GGUF format
1487                let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1488                let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1489                let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1490
1491                let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1492                    QuantMethodConfig::Unquantized(Linear::new(gate_proj.clone(), None)),
1493                )?);
1494                let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1495                    QuantMethodConfig::Unquantized(Linear::new(up_proj.clone(), None)),
1496                )?);
1497                let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1498                    QuantMethodConfig::Unquantized(Linear::new(down_proj.clone(), None)),
1499                )?);
1500                // Use apply_immediate_isq_always to ensure ISQ is applied to expert weights
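                // (`_always` is used here presumably because the stacked expert tensors have no
                // per-layer names for the usual ISQ predicate matching to key off.)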
1501                let device = gate_proj.device();
1502                fused_gate_proj = apply_immediate_isq_always(fused_gate_proj, device)?;
1503                fused_up_proj = apply_immediate_isq_always(fused_up_proj, device)?;
1504                fused_down_proj = apply_immediate_isq_always(fused_down_proj, device)?;
1505
1506                (fused_gate_proj, fused_up_proj, fused_down_proj)
1507            }
1508        } else if is_stacked_format {
1509            // Stacked format from safetensors:
1510            // - gate_up_proj: [num_experts, hidden_size, intermediate_size * 2] = [128, 2048, 1536]
1511            // - down_proj: [num_experts, intermediate_size, hidden_size] = [128, 768, 2048]
1512            //
1513            // GGUF/indexed_moe_forward expects:
1514            // - gate/up: [num_experts, intermediate_size, hidden_size] = [128, 768, 2048]
1515            // - down: [num_experts, hidden_size, intermediate_size] = [128, 2048, 768]
1516            let gate_up_proj = experts_vb.get(
1517                (num_experts, hidden_size, moe_intermediate_size * 2),
1518                "gate_up_proj",
1519            )?;
1520            let down_proj_packed = experts_vb.get(
1521                (num_experts, moe_intermediate_size, hidden_size),
1522                "down_proj",
1523            )?;
1524
1525            // Split gate_up_proj into gate_proj and up_proj along the last dimension
1526            // gate_proj: [num_experts, hidden_size, intermediate_size]
1527            // up_proj: [num_experts, hidden_size, intermediate_size]
1528            let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1529            let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1530
1531            // Transpose dims 1 and 2 to match GGUF format:
1532            // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1533            let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1534            let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1535            // down_proj: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1536            let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1537
1538            let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1539                QuantMethodConfig::Unquantized(Linear::new(gate_proj.clone(), None)),
1540            )?);
1541            let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1542                QuantMethodConfig::Unquantized(Linear::new(up_proj.clone(), None)),
1543            )?);
1544            let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1545                QuantMethodConfig::Unquantized(Linear::new(down_proj.clone(), None)),
1546            )?);
1547            // Use apply_immediate_isq_always to ensure ISQ is applied to expert weights
1548            let device = gate_proj.device();
1549            fused_gate_proj = apply_immediate_isq_always(fused_gate_proj, device)?;
1550            fused_up_proj = apply_immediate_isq_always(fused_up_proj, device)?;
1551            fused_down_proj = apply_immediate_isq_always(fused_down_proj, device)?;
1552
1553            (fused_gate_proj, fused_up_proj, fused_down_proj)
1554        } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1555            // Per-expert format with FP8 quantization
1556            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1557            let weight_block_size = match quantization_config {
1558                Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1559                _ => unreachable!(),
1560            };
1561
1562            if weight_block_size.len() != 2 {
1563                candle_core::bail!(
1564                    "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1565                );
1566            }
1567
1568            let mut gate_fp8_vec = Vec::new();
1569            let mut gate_scale_vec = Vec::new();
1570            let mut up_fp8_vec = Vec::new();
1571            let mut up_scale_vec = Vec::new();
1572            let mut down_fp8_vec = Vec::new();
1573            let mut down_scale_vec = Vec::new();
1574
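            // Load each expert's FP8 weights and blockwise scales individually, then stack them
            // along dim 0 below so they can be handed to `blockwise_fp8_moe` in the fused
            // [num_experts, N, K] layout.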
1575            for i in 0..num_experts {
1576                let expert_vb = experts_vb.pp(i);
1577
1578                // Load FP8 weights and scales for each projection
1579                let gate_fp8 = expert_vb.get_with_hints_dtype(
1580                    (moe_intermediate_size, hidden_size),
1581                    "gate_proj.weight",
1582                    Default::default(),
1583                    candle_core::DType::F8E4M3,
1584                )?;
1585                let gate_scale = expert_vb.get_with_hints_dtype(
1586                    (
1587                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1588                        hidden_size.div_ceil(weight_block_size[1]),
1589                    ),
1590                    "gate_proj.weight_scale_inv",
1591                    Default::default(),
1592                    candle_core::DType::F32,
1593                )?;
1594
1595                let up_fp8 = expert_vb.get_with_hints_dtype(
1596                    (moe_intermediate_size, hidden_size),
1597                    "up_proj.weight",
1598                    Default::default(),
1599                    candle_core::DType::F8E4M3,
1600                )?;
1601                let up_scale = expert_vb.get_with_hints_dtype(
1602                    (
1603                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1604                        hidden_size.div_ceil(weight_block_size[1]),
1605                    ),
1606                    "up_proj.weight_scale_inv",
1607                    Default::default(),
1608                    candle_core::DType::F32,
1609                )?;
1610
1611                let down_fp8 = expert_vb.get_with_hints_dtype(
1612                    (hidden_size, moe_intermediate_size),
1613                    "down_proj.weight",
1614                    Default::default(),
1615                    candle_core::DType::F8E4M3,
1616                )?;
1617                let down_scale = expert_vb.get_with_hints_dtype(
1618                    (
1619                        hidden_size.div_ceil(weight_block_size[0]),
1620                        moe_intermediate_size.div_ceil(weight_block_size[1]),
1621                    ),
1622                    "down_proj.weight_scale_inv",
1623                    Default::default(),
1624                    candle_core::DType::F32,
1625                )?;
1626
1627                gate_fp8_vec.push(gate_fp8);
1628                gate_scale_vec.push(gate_scale);
1629                up_fp8_vec.push(up_fp8);
1630                up_scale_vec.push(up_scale);
1631                down_fp8_vec.push(down_fp8);
1632                down_scale_vec.push(down_scale);
1633            }
1634
1635            // Stack into [num_experts, N, K]
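            // (N is each expert's output features and K its input features, e.g. gate/up become
            // [num_experts, moe_intermediate_size, hidden_size].)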
1636            let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
1637            let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
1638            let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
1639            let up_scale = Tensor::stack(&up_scale_vec, 0)?;
1640            let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
1641            let down_scale = Tensor::stack(&down_scale_vec, 0)?;
1642
1643            // Create BlockwiseFP8Linear for each projection
1644            let fused_gate_proj =
1645                blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1646            let fused_up_proj =
1647                blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1648            let fused_down_proj =
1649                blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1650
1651            (fused_gate_proj, fused_up_proj, fused_down_proj)
1652        } else {
1653            // Per-expert format: load each expert individually and stack
1654            let mut gate_proj_vec = Vec::new();
1655            let mut up_proj_vec = Vec::new();
1656            let mut down_proj_vec = Vec::new();
1657            for i in 0..num_experts {
1658                let expert_vb = experts_vb.pp(i);
1659                let gate_proj =
1660                    expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1661                let up_proj =
1662                    expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1663                let down_proj =
1664                    expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1665
1666                gate_proj_vec.push(gate_proj);
1667                up_proj_vec.push(up_proj);
1668                down_proj_vec.push(down_proj);
1669            }
1670
1671            let mut gate_proj: Arc<dyn QuantMethod> =
1672                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1673                    Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1674                ))?);
1675            let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1676                QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
1677            )?);
1678            let mut down_proj: Arc<dyn QuantMethod> =
1679                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1680                    Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1681                ))?);
1682            // Use the experts.0.{proj} prefix so ISQ predicate matching sees the actual per-expert weight paths
1683            let expert0_vb = experts_vb.pp("0");
1684            gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
1685            up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
1686            down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;
1687
1688            (gate_proj, up_proj, down_proj)
1689        };
1690
1691        Ok(Self {
1692            fused_gate_proj,
1693            fused_up_proj,
1694            fused_down_proj,
1695        })
1696    }
1697}
1698
1699/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
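///
/// As an illustrative example (values chosen here for clarity, not taken from any
/// particular model): with `world_size = 8`, `total_num_kv_heads = 4`, and
/// `head_dim = 128`, each KV head is replicated across `kv_replicate = 8 / 4 = 2`
/// ranks, `num_kv_heads = max(4 / 8, 1) = 1`, and rank 3 receives
/// `Shard::Offset { dim: 0, offset: (3 / 2) * 1 * 128 = 128, len: 128 }`.
/// If instead `world_size <= total_num_kv_heads`, the heads are simply split with
/// `Shard::Simple` along dim 0.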
1700pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
1701    if comm.world_size() == 1 {
1702        return Shard::default();
1703    }
1704
1705    // Tensor parallelism case
1706
1707    // We may need to replicate the kv heads
1708    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1709        comm.world_size() / total_num_kv_heads
1710    } else {
1711        return Shard::Simple {
1712            dim: 0,
1713            rank: comm.rank(),
1714            world_size: comm.world_size(),
1715        };
1716    };
1717
1718    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
1719    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
1720    Shard::Offset {
1721        dim: 0,
1722        offset: kv_shard_id * head_dim,
1723        len: head_dim,
1724    }
1725}
1726
1727/// Compute the number of KV groups, taking into account KV head replication.
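///
/// Continuing the illustrative numbers above: with `num_attention_heads = 32`,
/// `total_num_kv_heads = 4`, and `world_size = 8`, `kv_replicate = 2` and the
/// result is `(32 / 4) / 2 = 4`; with `world_size <= 4` it is simply `32 / 4 = 8`.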
1728pub fn compute_n_kv_groups(
1729    total_num_kv_heads: usize,
1730    num_attention_heads: usize,
1731    comm: &Comm,
1732) -> usize {
1733    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1734        comm.world_size() / total_num_kv_heads
1735    } else {
1736        1
1737    };
1738    if kv_replicate != 0 {
1739        (num_attention_heads / total_num_kv_heads) / kv_replicate
1740    } else {
1741        num_attention_heads / total_num_kv_heads
1742    }
1743}