mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring backend config from the JSON file at the path given by the
    /// `RING_CONFIG` environment variable.
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}
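
// Illustrative `RING_CONFIG` contents (a sketch with made-up addresses and ports; the
// field names mirror `RingConfig` above). Rank 1 of a two-node ring might use:
//
//   {
//     "master_ip": "192.168.1.10",
//     "master_port": 6000,
//     "port": 7001,
//     "right_ip": "192.168.1.10",
//     "right_port": 7000,
//     "rank": 1,
//     "world_size": 2
//   }
//
// The process would then be launched with `RING_CONFIG=/path/to/ring.json` so that
// `RingConfig::load()` can find it.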

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(all(feature = "cuda", feature = "ring"))]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(all(not(feature = "cuda"), feature = "ring"))]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(all(feature = "cuda", feature = "nccl"))]
    {
        // In case the TP world size was set manually
        if let Ok(x) = std::env::var("MISTRALRS_MN_LOCAL_WORLD_SIZE") {
            use std::str::FromStr;
            Ok(usize::from_str(&x).expect("Not a number for MISTRALRS_MN_LOCAL_WORLD_SIZE!"))
        } else {
            use candle_core::cuda::WrapErr;
            candle_core::cuda::cudarc::driver::result::device::get_count()
                .w()
                .map(|x| x as usize)
        }
    }

    #[cfg(all(not(feature = "ring"), not(feature = "nccl")))]
    Ok(1)
}
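
// For example (an illustration, not a prescribed setup): launching with
// `MISTRALRS_MN_LOCAL_WORLD_SIZE=4` on a CUDA + NCCL build makes
// `get_global_tp_size_from_devices()` report a tensor-parallel size of 4, regardless
// of how many GPUs `device::get_count()` would otherwise detect.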

pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}
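
// Note: even when the `nccl` and `cuda` features are compiled in, NCCL can be disabled
// at runtime by setting `MISTRALRS_NO_NCCL=1`, in which case `use_nccl()` returns false
// and `Comm::from_device` falls back to the ring or dummy backend.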

// Unified Comm enum
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}
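
// Illustrative usage (a sketch, not an exhaustive example): callers typically build a
// `Comm` for a device, wrap it in an `Arc`, and hand it to the collective wrappers
// defined further below. Here `device` and `tensor` are assumed to already exist
// (a `candle_core::Device` and a `candle_core::Tensor` on that device).
//
//   let id = Id::new();
//   let comm = std::sync::Arc::new(Comm::from_device(id, &device, /* rank */ 0, /* world_size */ 1)?);
//   let summed = SumAllReduce::new(&comm).sum_all_reduce(&tensor)?;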

// Unified Id enum
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

// NCCL backend implementation
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            let device_ordinal = stream.context().ordinal();
            if rank != device_ordinal {
                candle_core::bail!(
                    "NCCL rank {} must match device ordinal, but device ordinal is {}. \
                     Ensure GPUs are visible in the correct order (check CUDA_VISIBLE_DEVICES).",
                    rank,
                    device_ordinal
                );
            }
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => candle_core::bail!("Expected NCCL Id variant for NCCL Comm initialization"),
            };
            tracing::info!(
                "Initializing NCCL communicator: rank={}, world_size={}, device={}",
                rank,
                world_size,
                device_ordinal
            );
            let comm = cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                .map_err(|e| candle_core::Error::debug(e.0))?;
            Ok(Self { comm })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

// Ring backend implementation
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            // Validate ring configuration
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

// Dummy backend implementation
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

// Unified operations that work with the Comm enum
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}
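
// For example (illustrative shapes only): with `world_size == 2` and a per-rank tensor
// of shape (4, 8), `AllGather::new(&comm, 1)` gathers along dimension 1, so
// `all_gather` returns a tensor of shape (4, 16) containing both ranks' shards.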

// Implementation modules for each backend
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (BF16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (F16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (F32): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] *= self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (BF16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (F16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (F32): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

// Ring operations
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    // Friendly aliases to tame type complexity.
    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    // Lazily initialized pair of TCP streams shared by every ring-based collective op
    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                // Connect to the right neighbor using the provided IP
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                // Accept the connection from the left neighbor
                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }
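
    // Wiring example (hypothetical ports, two ranks): rank 0 with `port = 7000` and
    // `right_port = 7001` binds 0.0.0.0:7000 and connects to rank 1 at
    // `right_ip():7001`; rank 1 with `port = 7001` and `right_port = 7000` does the
    // mirror image. Each rank thus ends up with one accepted "left" stream and one
    // outgoing "right" stream, which `get_ring_streams` caches in
    // `LEFT_RIGHT_STREAMS` for reuse by the collective ops below.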

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = std::mem::size_of_val(x);

            // --- ping-pong to overlap latency ---------------------------------------
            // Clone the Arc references
            let right = self.right.clone();
            let left = self.left.clone();

            // View the local slice as bytes that can be written on the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Reuse (or allocate) a receive buffer of identical size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            // Lock both sockets once to avoid per-call mutex overhead.
            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            // For the typical tensor size we see (~6 KiB) a single
            // write/read pair is faster than chunking because the extra
            // system-call and loop overhead dominates. Only fall back to the
            // chunked "ping-pong" pipeline for larger transfers.
            if nbytes <= 8 * 1024 {
                // --- fast path: one shot ------------------------------------
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // --- slow path: chunked ping-pong ---------------------------
                const CHUNK_SIZE: usize = 64 * 1024; // 64 KiB
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    // send this chunk to the right neighbor
                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    // receive the matching chunk from the left neighbor
                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // -------------------------------------------------------------------------
            // Interpret the received bytes as a slice of T; `sum_all_reduce` below adds
            // the resulting tensor element-wise to the local input.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Validate gather dimension
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = std::mem::size_of_val(x);

            // Prepare output buffer that will hold slices from every rank.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            // Copy this rank's slice into its final slot.
            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                // ---------- send to the right ----------
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                // ---------- receive from the left ----------
                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // Determine which global rank the received slice came from.
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward that slice in the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }
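
        // Worked example of the slice routing above (illustrative, world_size = 4,
        // rank = 2): at step 0 the received piece came from rank (2 + 4 - 0 - 1) % 4 = 1,
        // at step 1 from rank 0, and at step 2 from rank 3, so after world_size - 1
        // steps every rank's slice has been placed at offset `src_rank * elem_cnt` in
        // `out`.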

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

// Dummy operations
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}