use std::sync::Arc;

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::blockwise_fp8_linear_b,
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    should_apply_immediate_isq,
    utils::isq::{apply_immediate_isq, apply_immediate_isq_always},
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, FP8Linear, GgufMatMul, HqqLayer, QuantMethod,
    QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, QuantizedSerdeType,
    Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

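/// A linear layer whose weight is sharded along the input dimension (dim 1).
/// Each rank computes a partial matmul; the partial outputs are summed via an
/// all-reduce in `forward`, after which the unsharded bias is added.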
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // No quantization config: load an unquantized layer, or a dummy layer if the
            // checkpoint does not contain the weight.
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // The bias is not sharded; it is applied once, after the all-reduce in `forward`.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (out_dim, orig_intermediate_size),
                    "weight",
                    Default::default(),
                )?
                .i((.., ..in_dim))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

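/// A linear layer whose weight and bias are sharded along the output
/// dimension (dim 0). Each rank produces its own slice of the output, so
/// `forward` needs no collective communication.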
#[derive(Debug)]
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => {
                    blockwise_fp8_linear_b(in_dim, out_dim, quant_conf, false, shard, vb.clone())?
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // No quantization config: load an unquantized layer, or a dummy layer if the
            // checkpoint does not contain the weight.
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // The bias is sharded along the output dimension, matching the weight.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (orig_intermediate_size, in_dim),
                    "weight",
                    Default::default(),
                )?
                .i((..out_dim, ..))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            // Treat the merged weight as `chunks * world_size` column blocks:
            // chunk `chunk_idx` on this rank owns block `chunk_idx * world_size + rank`.
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

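/// A linear layer that is fully replicated on every rank; no sharding or
/// collective communication is involved.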
#[derive(Debug)]
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        let this_unquant = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
        let this: Arc<dyn QuantMethod> = apply_immediate_isq_always(this_unquant, &dev)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_layers_matformer_indices(
        in_dim: usize,
        out_dim: usize,
        kept_layers_indices: Option<&Tensor>,
        orig_num_hidden_layers: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            if kept_layers_indices.is_some() {
                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { .. } => blockwise_fp8_linear_b(
                    in_dim,
                    out_dim,
                    quant_conf,
                    bias,
                    Default::default(),
                    vb.clone(),
                )?,
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let mut weight =
                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;

                // Matformer slicing: view the weight as one block per original hidden
                // layer, keep only the selected layers, then flatten back to 2D.
                if let Some(kept_layers_indices) = &kept_layers_indices {
                    let weight_reshaped = weight.reshape((
                        orig_num_hidden_layers,
                        weight.dim(0)? / orig_num_hidden_layers,
                        weight.dim(1)?,
                    ))?;

                    weight = weight_reshaped
                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
                        .contiguous()?;
                }

                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<[u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

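/// Per-projection expert weights for a mixture-of-experts block. With an AFQ
/// quantization config each projection is loaded as a single packed layer;
/// otherwise the fused `gate_up_proj`/`down_proj` tensors are sharded across
/// ranks and split into one unquantized layer per local expert.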
#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            // No fused expert tensors in the checkpoint: fall back to dummy layers.
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            // The fused `gate_up_proj` tensor has shape
            // (num_local_experts, hidden_size, 2 * intermediate_size): the gate projection
            // occupies the first `intermediate_size` columns and the up projection the
            // second. Each rank takes its `intermediate_size / world_size` slice of each
            // half, and `down_proj` is sharded along its intermediate dimension to match.
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            // Split the stacked expert weights into one layer per local expert.
            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

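/// Expert gate/up/down projections fused into a single stacked layer per
/// projection (one tensor covering all experts). Construction currently
/// requires a Metal device.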
pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if !vb.device().is_metal() {
            candle_core::bail!("FastMoeMlp requires Metal.");
        }

        let (fused_gate_proj, fused_up_proj, fused_down_proj) =
            if matches!(&quantization_config, Some(QuantizedConfig::Afq { .. })) {
                let quantization_config = quantization_config.as_ref().unwrap();

                let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.gate_proj"),
                )?;
                let fused_up_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    hidden_size,
                    moe_intermediate_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.up_proj"),
                )?;
                let fused_down_proj = AfqLayer::afq_packed_linear_b(
                    num_experts,
                    moe_intermediate_size,
                    hidden_size,
                    quantization_config,
                    false,
                    vb.pp("switch_mlp.down_proj"),
                )?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);

                    let gate_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("gate_proj.weight"),
                    )?;
                    let up_proj = crate::linear_no_bias(
                        hidden_size,
                        moe_intermediate_size,
                        quantization_config,
                        vb.pp("up_proj.weight"),
                    )?;
                    let down_proj = crate::linear_no_bias(
                        moe_intermediate_size,
                        hidden_size,
                        quantization_config,
                        vb.pp("down_proj.weight"),
                    )?;

                    gate_proj_vec.push(gate_proj.dequantize_w()?);
                    up_proj_vec.push(up_proj.dequantize_w()?);
                    down_proj_vec.push(down_proj.dequantize_w()?);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            } else {
                let experts_vb = vb.pp("experts");
                let mut gate_proj_vec = Vec::new();
                let mut up_proj_vec = Vec::new();
                let mut down_proj_vec = Vec::new();
                for i in 0..num_experts {
                    let vb = experts_vb.pp(i);
                    let gate_proj =
                        vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                    let up_proj = vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                    let down_proj =
                        vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                    gate_proj_vec.push(gate_proj);
                    up_proj_vec.push(up_proj);
                    down_proj_vec.push(down_proj);
                }

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                    ))?);
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&up_proj_vec, 0)?, None),
                    ))?);
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                        Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                    ))?);
                gate_proj = apply_immediate_isq(gate_proj, vb.pp("gate_proj"))?;
                up_proj = apply_immediate_isq(up_proj, vb.pp("up_proj"))?;
                down_proj = apply_immediate_isq(down_proj, vb.pp("down_proj"))?;

                (gate_proj, up_proj, down_proj)
            };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

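/// Computes the shard used to load the K/V projection weights for this rank.
/// When there are at least as many KV heads as ranks, the heads are split
/// evenly along dim 0; otherwise each KV head is replicated across
/// `world_size / total_num_kv_heads` ranks and the shard selects that head's
/// `head_dim` rows.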
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

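/// Computes the number of KV groups (query heads per KV head), accounting for
/// KV-head replication when the world size exceeds the number of KV heads.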
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}