use std::sync::Arc;

use candle_core::{Context, Result, Tensor};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b, distributed, gptq::gptq_linear,
    lora::merge_lora_weights, AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear,
    GgufMatMul, HqqLayer, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
    QuantizedSerde, QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

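/// Build a simple shard descriptor for the given dimension, rank, and world size.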
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

#[derive(Debug)]
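/// A row-parallel linear layer. The weight is sharded along the input dimension
/// (dim 1), so each rank computes a partial matmul; the partial results are
/// combined with a sum all-reduce in `forward`.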
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
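        // Shard along dim 1 (the input features) of the (out_dim, in_dim) weight.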
        let shard = shard(1, rank, world_size);

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::Gptq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Gptq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

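        // The bias is applied after the all-reduce, so it is loaded unsharded.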
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        }))
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
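        // Each rank only holds a partial sum over its input shard; combine across ranks.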
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

#[derive(Debug)]
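/// A column-parallel linear layer. The weight (and bias, if present) are sharded
/// along the output dimension (dim 0), so each rank produces a disjoint slice of
/// the output features and no reduction is needed in `forward`.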
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::Gptq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Gptq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

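        // The bias is sharded along with the output features.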
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        Ok(Arc::new(Self { weight, bias }))
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
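        // Shard along dim 0 (the output features) of the (out_dim, in_dim) weight.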
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

#[derive(Debug)]
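/// A layer whose full, unsharded weight is replicated on every rank; no
/// communication is performed.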
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(lin),
        )?))
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::Gptq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb,
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb)?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        Ok(Arc::new(Self(layer)))
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

#[derive(Debug)]
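/// Per-expert gate, up, and down projections for a packed MoE layer. The
/// unquantized path loads a fused `gate_up_proj` tensor and splits it into
/// per-expert gate and up linears; the AFQ path loads packed quantized weights
/// for each projection directly.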
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support distributed inference (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    (
                        vec![AfqLayer::afq_packed_linear_b(
                            num_local_experts,
                            hidden_size,
                            intermediate_size,
                            quant_conf,
                            bias,
                            vb.pp("gate_proj"),
                        )?],
                        vec![AfqLayer::afq_packed_linear_b(
                            num_local_experts,
                            hidden_size,
                            intermediate_size,
                            quant_conf,
                            bias,
                            vb.pp("up_proj"),
                        )?],
                        vec![AfqLayer::afq_packed_linear_b(
                            num_local_experts,
                            intermediate_size,
                            hidden_size,
                            quant_conf,
                            bias,
                            vb.pp("down_proj"),
                        )?],
                    )
                }
                _ => candle_core::bail!(
                    "PackedExperts with a quantization config only supports AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("down_proj") {
            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for _ in 0..num_local_experts {
                let gate_proj = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                let up_proj = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                let down_proj = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                gs.push(Arc::new(gate_proj) as Arc<dyn QuantMethod>);
                us.push(Arc::new(up_proj) as Arc<dyn QuantMethod>);
                ds.push(Arc::new(down_proj) as Arc<dyn QuantMethod>);
            }
            (gs, us, ds)
        } else {
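            // `gate_up_proj` is stored as (num_local_experts, hidden_size, 2 * intermediate_size),
            // with the gate and up halves concatenated on the last dim. Shard each half along
            // that dim; shard `down_proj` along its input (intermediate) dim.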
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in gate_proj
                .chunk(num_local_experts, 0)?
                .into_iter()
                .zip(up_proj.chunk(num_local_experts, 0)?)
                .zip(down_proj.chunk(num_local_experts, 0)?)
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let gate_proj = <UnquantLinear as QuantMethod>::new(
                    QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                )?;
                let up_proj = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(up_proj, None),
                ))?;
                let down_proj = <UnquantLinear as QuantMethod>::new(
                    QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                )?;
                gs.push(Arc::new(gate_proj) as Arc<dyn QuantMethod>);
                us.push(Arc::new(up_proj) as Arc<dyn QuantMethod>);
                ds.push(Arc::new(down_proj) as Arc<dyn QuantMethod>);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

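/// Compute the weight shard for the K/V projection under tensor parallelism.
/// When there are at least as many KV heads as ranks, the heads are split across
/// ranks; otherwise each KV head is replicated across several ranks and each rank
/// selects a single head-sized slice.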
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

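    // If there are more ranks than KV heads, each KV head has to be replicated
    // across `world_size / total_num_kv_heads` ranks.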
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

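/// Compute the number of KV groups (query heads per KV head), divided by the
/// KV replication factor when KV heads are replicated across ranks.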
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}