use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

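/// Network layout for the ring (TCP) backend. Each rank listens on `port` and connects to its
/// right neighbor at `right_ip:right_port`; rank 0 is the master and may omit `master_ip`.
/// An illustrative config file for rank 0 of a two-node ring (values are examples only):
///
/// ```json
/// {
///   "master_port": 12345,
///   "port": 30000,
///   "right_port": 30001,
///   "right_ip": "192.168.1.2",
///   "rank": 0,
///   "world_size": 2
/// }
/// ```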
#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
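    /// Reads the path to the ring config JSON from the `RING_CONFIG` environment variable and
    /// deserializes it.
    ///
    /// Panics if `RING_CONFIG` is unset, the file cannot be opened or parsed, or a non-master
    /// rank omits `master_ip`.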
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

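/// Returns the global tensor-parallel world size. Depending on the enabled backend features this
/// is the CUDA device count, the `world_size` from the ring config, the
/// `MISTRALRS_MN_LOCAL_WORLD_SIZE` override, or 1 when no distributed backend is compiled in.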
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(all(feature = "cuda", feature = "ring"))]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(all(not(feature = "cuda"), feature = "ring"))]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(all(feature = "cuda", feature = "nccl"))]
    {
        if let Ok(x) = std::env::var("MISTRALRS_MN_LOCAL_WORLD_SIZE") {
            use std::str::FromStr;
            Ok(usize::from_str(&x).expect("Not a number for MISTRALRS_MN_LOCAL_WORLD_SIZE!"))
        } else {
            use candle_core::cuda::WrapErr;
            candle_core::cuda::cudarc::driver::result::device::get_count()
                .w()
                .map(|x| x as usize)
        }
    }

    #[cfg(all(not(feature = "ring"), not(feature = "nccl")))]
    Ok(1)
}

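/// NCCL is used when the `nccl` and `cuda` features are enabled and the `MISTRALRS_NO_NCCL`
/// environment variable is not set to "1".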
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

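/// A communicator over one of the available backends. `from_device` prefers NCCL when it is
/// enabled, then the ring backend, and otherwise falls back to the single-process dummy backend.
///
/// Illustrative usage (assumes a CUDA device is available; not taken from a test in this crate):
///
/// ```ignore
/// let dev = candle_core::Device::new_cuda(0)?;
/// let comm = std::sync::Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
/// let all_reduce = SumAllReduce::new(&comm);
/// ```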
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}

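/// Identifier shared by all ranks when initializing a communicator: a real 128-byte NCCL unique
/// id when NCCL is in use, otherwise a zeroed placeholder.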
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

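/// Thin wrapper around `cudarc`'s NCCL communicator, with sanity checks on the rank/device
/// mapping and the world size.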
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            let device_ordinal = stream.context().ordinal();
            if rank != device_ordinal {
                candle_core::bail!(
                    "NCCL rank {} must match device ordinal, but device ordinal is {}. \
                     Ensure GPUs are visible in the correct order (check CUDA_VISIBLE_DEVICES).",
                    rank,
                    device_ordinal
                );
            }
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => candle_core::bail!("Expected NCCL Id variant for NCCL Comm initialization"),
            };
            tracing::info!(
                "Initializing NCCL communicator: rank={}, world_size={}, device={}",
                rank,
                world_size,
                device_ordinal
            );
            let comm = cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                .map_err(|e| candle_core::Error::debug(e.0))?;
            Ok(Self { comm })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

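/// Ring (TCP) backend: the communicator itself only loads and validates the `RingConfig`; the
/// actual socket traffic lives in `ring_ops`.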
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

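/// Single-process fallback used when no distributed backend is active.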
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

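/// Backend-dispatching sum all-reduce: every rank contributes a tensor and each receives the
/// elementwise sum across all ranks. Exactly one of the backend fields is populated, matching
/// the `Comm` variant passed to `new`.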
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

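/// Backend-dispatching all-gather: concatenates each rank's tensor along `dim`, so the output is
/// `world_size` times larger along that dimension.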
#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}

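/// NCCL implementations of the collective ops, written as candle `CustomOp1`s so they run
/// directly on the CUDA storage of the input tensor.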
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (BF16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (F16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (F32): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (BF16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (F16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (F32): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

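/// TCP ring implementations of the collective ops. Each rank binds `0.0.0.0:port` for its left
/// neighbor and connects to `right_ip:right_port`; the two streams are created once and cached
/// in a `OnceLock`. Tensor data is staged through CPU memory on both ends.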
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right neighbor within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of::<T>();

            let right = self.right.clone();
            let left = self.left.clone();

            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

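            // Exchange with the neighbors: write our buffer to the right peer and read the same
            // number of bytes from the left peer. Small payloads go out in one shot; larger ones
            // are interleaved in fixed-size chunks, which helps avoid both sides stalling with
            // full TCP send buffers while their peers are still writing.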
            if nbytes <= 8 * 1024 {
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                const CHUNK_SIZE: usize = 64 * 1024;
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of::<T>();

            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            // Ring rotation: at each step, forward the piece received in the previous step to
            // the right and receive the next originating rank's piece from the left.
            for step in 0..(self.world_size - 1) {
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

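/// No-op single-process implementations: both collectives return the input tensor unchanged.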
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}