mistralrs_quant/distributed/layers.rs

use std::sync::Arc;

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b,
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType,
    Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}
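
// For example, `shard(1, 1, 2)` builds `Shard::Simple { dim: 1, rank: 1, world_size: 2 }`,
// which (under the usual VarBuilder sharding convention) loads the second of two equal
// chunks along dimension 1; for an (out_dim, in_dim) weight that is the second half of the
// input features.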

/// This layer has a weight that is parallelized along the input dimension,
/// returning the "full" output dimension.
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}
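
// Row-parallel sketch (assuming the usual tensor-parallel convention): the full weight W has
// shape (out_dim, in_dim) and rank r loads the column slice W[:, r*in_dim/ws..(r+1)*in_dim/ws].
// Each rank computes a partial product over its input shard, and the SumAllReduce sums the
// partials so every rank ends up with the complete x @ W^T; see `forward` below, which
// reduces before adding the bias.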

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        // Handle the case where the layer is dummy (no tensors)
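        // The matformer path loads the weight at its full `orig_intermediate_size` input width
        // and keeps only the first `in_dim` columns below, before sharding and merging any LoRA
        // weights.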
        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (out_dim, orig_intermediate_size),
                    "weight",
                    Default::default(),
                )?
                .i((.., ..in_dim))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
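        // Sum the per-rank partial products; the bias (if any) is added only after the
        // all-reduce so it is applied exactly once rather than once per rank.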
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

#[derive(Debug)]
/// This layer has a weight that is parallelized along the output dimension,
/// taking the "full" input dimension.
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}
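
// Column-parallel sketch (assuming the usual tensor-parallel convention): the full weight W has
// shape (out_dim, in_dim) and rank r loads the row slice W[r*out_dim/ws..(r+1)*out_dim/ws, :].
// Each rank's `forward` therefore produces its own slice of the output features, the bias is
// sharded the same way, and no all-reduce is needed here; a RowParallelLayer typically consumes
// these shards downstream.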

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        // Handle the case where the layer is dummy (no tensors)
        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (orig_intermediate_size, in_dim),
                    "weight",
                    Default::default(),
                )?
                .i((..out_dim, ..))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

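    /// Illustrative shard indexing: with `chunks = 2` (e.g. a fused gate/up projection) and
    /// `world_size = 2`, rank 1 builds chunk 0 with `shard(0, 1, 4)` and chunk 1 with
    /// `shard(0, 3, 4)`, so each returned layer holds this rank's slice of one logical chunk
    /// of the merged weight.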
    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

#[derive(Debug)]
/// This layer has no parallelization
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_layers_matformer_indices(
        in_dim: usize,
        out_dim: usize,
        kept_layers_indices: Option<&Tensor>,
        orig_num_hidden_layers: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            if kept_layers_indices.is_some() {
                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let mut weight =
                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;

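                // For matformer-style layer selection, the weight is stored stacked over all
                // `orig_num_hidden_layers` along dim 0. Reshape to
                // (orig_num_hidden_layers, rows_per_layer, in_dim), select the kept layers'
                // rows, then flatten back to 2D.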
                if let Some(kept_layers_indices) = &kept_layers_indices {
                    let weight_reshaped = weight.reshape((
                        orig_num_hidden_layers,
                        weight.dim(0)? / orig_num_hidden_layers,
                        weight.dim(1)?,
                    ))?;

                    weight = weight_reshaped
                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
                        .contiguous()?;
                }

                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    /// Note: we only support AFQ and unquantized here because they are the only ones that support indexed matmul.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Pre-quantized PackedExperts do not support tensor parallelism
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            // Parallelized like:
            // Each GPU holds all experts.
            // Gate/Up proj is parallelized on dim 2 (column)
            // Down proj is parallelized on dim 1 (row)
            // All-reduce at the end.

            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            // Gate comes right before Up in the fused gate_up_proj tensor
            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };
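            // Worked example (illustrative): with intermediate_size = 8 and world_size = 2,
            // rank 1 has gate_up_block_size = 4 and gate_up_start = 4, so shard_gate covers
            // columns 4..8 and shard_up columns 12..16 of the fused
            // (experts, hidden, 2 * intermediate) gate_up_proj tensor, while shard_down takes
            // the matching half of down_proj along dim 1.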

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}
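
// Note (based on the constructor below): each `fused_*` projection holds all experts' weights
// stacked along dim 0 (shape (num_experts, out_dim, in_dim) in the unquantized case), either as
// an AFQ packed layer or as an `UnquantLinear` built with `Tensor::stack`.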

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if !vb.device().is_metal() {
            candle_core::bail!("FastMoeMlp requires Metal.");
        }

        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
                let quantization_config = quantization_config.as_ref().unwrap();

                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.gate_proj"),
                )?;
                let fused_up_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.up_proj"),
                )?;
                let fused_down_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    moe_intermediate_size,
                    hidden_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.down_proj"),
                )?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);

                    let gate_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("gate_proj.weight"),
                    )?;
                    let up_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("up_proj.weight"),
                    )?;
                    let down_proj = crate::linear_no_bias(
                        moe_intermediate_size,
                        hidden_size,
                        quantization_config,
                        vb.pp("down_proj.weight"),
                    )?;

                    gate_proj_vec.push(gate_proj.dequantize_w()?);
                    up_proj_vec.push(up_proj.dequantize_w()?);
                    down_proj_vec.push(down_proj.dequantize_w()?);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            } else {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);
                    let gate_proj =
                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                    let down_proj =
                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                    gate_proj_vec.push(gate_proj);
                    up_proj_vec.push(up_proj);
                    down_proj_vec.push(down_proj);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    // Tensor parallelism case

    // We may need to replicate the kv heads
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}
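
// Worked examples (illustrative): with 8 total KV heads, head_dim = 128 and a world size of 4,
// the early return yields `Shard::Simple { dim: 0, rank, world_size: 4 }` (two heads per rank).
// With 2 total KV heads and a world size of 8, kv_replicate = 4 and num_kv_heads = 1, so ranks
// 0..3 get `Shard::Offset { dim: 0, offset: 0, len: 128 }` and ranks 4..7 get
// `Shard::Offset { dim: 0, offset: 128, len: 128 }`, i.e. each KV head is replicated on 4 ranks.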

/// Compute the number of KV groups, taking into account KV head replication.
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}
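
// Worked example (illustrative): with 32 attention heads and 2 KV heads, a world size of 8 gives
// kv_replicate = 4 and (32 / 2) / 4 = 4 groups, matching the replicated KV shards above, while a
// world size of 1 or 2 leaves kv_replicate = 1 and the usual 32 / 2 = 16 groups.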