mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring backend config from the JSON file at the path given by the
    /// `RING_CONFIG` environment variable.
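    ///
    /// An illustrative two-node config for rank 1 might look like the following
    /// (placeholder addresses and ports; field names mirror this struct, and rank 0
    /// may omit `master_ip`):
    ///
    /// ```json
    /// {
    ///   "master_ip": "192.168.1.10",
    ///   "master_port": 26000,
    ///   "port": 26200,
    ///   "right_port": 26100,
    ///   "right_ip": "192.168.1.10",
    ///   "rank": 1,
    ///   "world_size": 2
    /// }
    /// ```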
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    Ok(1)
}

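/// Whether the NCCL backend should be used: both the `nccl` and `cuda` features must be
/// enabled, and the `MISTRALRS_NO_NCCL` environment variable must be unset or set to a
/// value other than "1".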
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

// Unified Comm enum
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}
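
// Minimal usage sketch (hypothetical): with none of the `cuda`/`nccl`/`ring` features
// enabled, `from_device` falls through to the dummy backend and the collective ops
// defined below become no-ops. `xs` is an assumed `candle_core::Tensor`.
//
//     let dev = candle_core::Device::Cpu;
//     let comm = std::sync::Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
//     let reduce = SumAllReduce::new(&comm);
//     let summed = reduce.sum_all_reduce(&xs)?; // identity on the dummy backend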

// Unified Id enum
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

// NCCL backend implementation
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let device = dev.as_cuda_device()?.cuda_device();
            assert_eq!(rank, device.ordinal());
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => panic!("Expected NCCL Id variant"),
            };
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, nccl_id)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

// Ring backend implementation
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            // Validate ring configuration
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

// Dummy backend implementation
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

// Unified operations that work with the Comm enum
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}
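
// Illustrative gather shapes (hypothetical `comm` and `xs`): with world_size = 2 and
// `xs` of shape (4, 8), gathering along dim 1 yields a tensor of shape (4, 16); the
// dummy backend returns the input unchanged.
//
//     let gather = AllGather::new(&comm, 1);
//     let gathered = gather.all_gather(&xs)?;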

// Implementation modules for each backend
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
        Layout, Result, Shape, Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert_eq!(dev.ordinal(), nccl_comm.rank());
                            assert!(elem_count > 0);
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::driver::DeviceSlice;
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert_eq!(dev.ordinal(), nccl_comm.rank());
                            assert!(elem_count > 0);
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

// Ring operations
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    // Friendly aliases to tame type complexity.
    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    // Lazily-initialized pair of TCP streams shared by every ring-based collective op
    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                // Connect to the right neighbor using the provided IP
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                // Accept connection from the left neighbour
                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }
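
    // Wiring sketch (illustrative ports): every rank binds `config.port` and accepts its
    // left neighbour there, while dialing `right_ip:right_port` to reach its right
    // neighbour. For a two-node ring, rank 0 could use { port: 26100, right_port: 26200 }
    // and rank 1 the mirror image { port: 26200, right_port: 26100 }.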

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of::<T>();

            // --- ping-pong to overlap latency ---------------------------------------
            // Clone the Arc references
            let right = self.right.clone();
            let left = self.left.clone();

            // View the local slice as bytes that can be written on the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Re-use (or allocate) a receive buffer of identical size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            // Lock both sockets once to avoid per-call mutex overhead.
            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            // For the typical tensor size we see (~6 KiB) a single
            // write/read pair is faster than chunking because the extra
            // system-call and loop overhead dominates. Only fall back to the
            // chunked "ping-pong" pipeline for larger transfers.
            if nbytes <= 8 * 1024 {
                // --- fast path: one shot ------------------------------------
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // --- slow path: chunked ping-pong ---------------------------
                const CHUNK_SIZE: usize = 64 * 1024; // 64 KiB
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    // send this chunk to the right neighbour
                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    // receive the matching chunk from the left neighbour
                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // -------------------------------------------------------------------------
            // Interpret the received bytes as a slice of T and build the delta tensor
            // that the caller sums into `xs`.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Validate gather dimension
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of::<T>();

            // Prepare output buffer that will hold slices from every rank.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            // Copy this rank's slice into its final slot.
            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                // ---------- send to the right ----------
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                // ---------- receive from the left ----------
                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // Determine which global rank the received slice came from.
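                // At step s the data has travelled s + 1 hops to the right, so the slice
                // arriving from the left originated at rank (self.rank - 1 - s), modulo
                // world_size. Worked example: with world_size = 4 and rank = 2, steps
                // 0, 1 and 2 receive the slices of ranks 1, 0 and 3, in that order.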
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward that slice in the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

// Dummy operations
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}