mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring backend config from the JSON file named by the `RING_CONFIG` environment variable.
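    ///
    /// The file should contain the fields of [`RingConfig`]; a rough example with
    /// illustrative values only:
    ///
    /// ```json
    /// {
    ///   "master_ip": "192.168.0.10",
    ///   "master_port": 12345,
    ///   "port": 4000,
    ///   "right_port": 4001,
    ///   "right_ip": "192.168.0.11",
    ///   "rank": 1,
    ///   "world_size": 2
    /// }
    /// ```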
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

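/// Returns the tensor-parallel world size implied by the enabled backend: the CUDA device
/// count when CUDA is available (optionally overridden by `MISTRALRS_MN_LOCAL_WORLD_SIZE`
/// under NCCL), the ring config's `world_size` for the CPU ring backend, and 1 when no
/// distributed backend is compiled in.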
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(all(feature = "cuda", feature = "ring"))]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(all(not(feature = "cuda"), feature = "ring"))]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(all(feature = "cuda", feature = "nccl"))]
    {
        // Allow a manually set TP size to override the device count.
        if let Ok(x) = std::env::var("MISTRALRS_MN_LOCAL_WORLD_SIZE") {
            use std::str::FromStr;
            Ok(usize::from_str(&x).expect("Not a number for MISTRALRS_MN_LOCAL_WORLD_SIZE!"))
        } else {
            use candle_core::cuda::WrapErr;
            candle_core::cuda::cudarc::driver::result::device::get_count()
                .w()
                .map(|x| x as usize)
        }
    }

    #[cfg(all(not(feature = "ring"), not(feature = "nccl")))]
    Ok(1)
}

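/// NCCL is used only when the `nccl` and `cuda` features are enabled and the
/// `MISTRALRS_NO_NCCL` environment variable is not set to "1".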
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

// Unified Comm enum
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
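    /// Picks the backend in priority order: NCCL (when compiled in and not disabled via
    /// `MISTRALRS_NO_NCCL`), then the TCP ring backend, then the single-process dummy backend.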
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}

// Unified Id enum
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

// NCCL backend implementation
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            assert_eq!(rank, stream.context().ordinal());
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => panic!("Expected NCCL Id variant"),
            };
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

// Ring backend implementation
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            // Validate ring configuration
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

// Dummy backend implementation
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

// Unified operations that work with the Comm enum
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}

// Implementation modules for each backend
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

// Ring operations
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    // Friendly aliases to tame type complexity.
    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    // Lazily-initialized pair of TCP streams shared by every ring-based collective op
    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

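    // Each rank listens for its left neighbor on `config.port` and dials its right
    // neighbor at `right_ip:right_port`, so the processes form a directed ring. The
    // two streams are established once and then reused by every collective call.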
    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                // Connect to the right neighbor using the provided IP
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node: timed out after 10 seconds");
                        }
                        Err(_) => continue,
                    }
                };

                // Accept connection from the left neighbour
                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of::<T>();

            // --- ping‑pong to overlap latency ---------------------------------------
            // Clone the Arc references
            let right = self.right.clone();
            let left = self.left.clone();

            // View the local slice as bytes that can be written on the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Re‑use (or allocate) a receive buffer of identical size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            // Lock both sockets once to avoid per-call mutex overhead.
            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            // For the typical tensor size we see (~ 6 KiB) a single
            // write/read pair is faster than chunking because the extra
            // system‑call and loop overhead dominates.  Only fall back to the
            // chunked "ping‑pong" pipeline for larger transfers.
            if nbytes <= 8 * 1024 {
                // --- fast path: one shot ------------------------------------
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // --- slow path: chunked ping‑pong ---------------------------
                const CHUNK_SIZE: usize = 64 * 1024; // 64 KiB
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    // send this chunk to the right neighbour
                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    // receive the matching chunk from the left neighbour
                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // -------------------------------------------------------------------------
            // Interpret the received bytes as a slice of T; the caller adds the result
            // to the local tensor element-wise (see `sum_all_reduce` below).
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Validate gather dimension
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of::<T>();

            // Prepare output buffer that will hold slices from every rank.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            // Copy this rank's slice into its final slot.
            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                // ---------- send to the right ----------
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                // ---------- receive from the left ----------
                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // Determine which global rank the received slice came from.
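                // e.g. with world_size = 4 and rank = 2: step 0 delivers rank 1's slice,
                // step 1 delivers rank 0's, and step 2 delivers rank 3's.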
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward that slice in the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

// Dummy operations
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}
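
// A minimal sanity-check sketch for the always-available dummy backend. It assumes a CPU
// `candle_core::Device` and is gated to builds without the `nccl`/`ring` features, where
// `Comm::from_device` resolves to `DummyComm` and both collectives are identity operations.
#[cfg(all(test, not(feature = "ring"), not(feature = "nccl")))]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn dummy_backend_is_identity() -> Result<()> {
        let dev = candle_core::Device::Cpu;
        let comm = Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
        assert_eq!(comm.rank(), 0);
        assert_eq!(comm.world_size(), 1);

        let xs = candle_core::Tensor::ones((2, 3), candle_core::DType::F32, &dev)?;
        // With the dummy backend both ops hand the input back unchanged.
        let reduced = SumAllReduce::new(&comm).sum_all_reduce(&xs)?;
        let gathered = AllGather::new(&comm, 0).all_gather(&xs)?;
        assert_eq!(reduced.dims(), xs.dims());
        assert_eq!(gathered.dims(), xs.dims());
        Ok(())
    }
}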