mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring backend config from the JSON file at the path given by
    /// the `RING_CONFIG` environment variable.
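    ///
    /// Illustrative example of such a file for a two-rank ring (the ports and
    /// addresses here are assumptions, not values shipped with the crate):
    ///
    /// ```text
    /// {
    ///   "master_ip": "192.168.1.10",
    ///   "master_port": 29500,
    ///   "port": 3001,
    ///   "right_port": 3000,
    ///   "right_ip": "192.168.1.10",
    ///   "rank": 1,
    ///   "world_size": 2
    /// }
    /// ```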
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    Ok(1)
}

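/// Returns true when the NCCL backend should be used: both the `nccl` and
/// `cuda` features must be enabled, and the `MISTRALRS_NO_NCCL` environment
/// variable must not be set to "1".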
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

// Unified Comm enum
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}
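
// Illustrative usage (a sketch, not taken from the crate's docs): with neither
// NCCL enabled nor the ring backend compiled in, `from_device` falls through
// to the dummy backend, which always reports rank 0 and world_size 1.
//
//     let comm = Comm::from_device(Id::new(), &candle_core::Device::Cpu, 0, 1)?;
//     assert_eq!((comm.rank(), comm.world_size()), (0, 1));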

// Unified Id enum
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

// NCCL backend implementation
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            assert_eq!(rank, stream.context().ordinal());
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => panic!("Expected NCCL Id variant"),
            };
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

// Ring backend implementation
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            // Validate ring configuration
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

// Dummy backend implementation
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

// Unified operations that work with the Comm enum
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}
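
// Illustrative usage of the unified ops (a sketch; `dev` and `xs` are assumed
// to exist in the calling code):
//
//     let comm = std::sync::Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
//     let summed = SumAllReduce::new(&comm).sum_all_reduce(&xs)?;
//     let gathered = AllGather::new(&comm, 0).all_gather(&xs)?; // concat along dim 0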

// Implementation modules for each backend
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }
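
    // `apply_op1_no_bwd` hands the tensor to the `CustomOp1` impl below, so the
    // actual all-reduce happens in `cuda_fwd`; `cpu_fwd` deliberately bails.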

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

// Ring operations
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    // Friendly aliases to tame type complexity.
    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    // Lazily-initialized pair of TCP streams shared by every ring-based collective op
    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                // Connect to the right neighbour using the provided IP
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node after a 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                // Accept the connection from the left neighbour
                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }
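
    // Wiring note (derived from the connection logic above; the port numbers
    // are purely illustrative): every rank listens on `0.0.0.0:port` for its
    // left neighbour and dials `right_ip:right_port` to reach its right
    // neighbour, so rank i's `right_port` must match rank (i + 1) % world_size's
    // `port`. A two-rank ring could use: rank 0 { port: 3000, right_port: 3001 },
    // rank 1 { port: 3001, right_port: 3000 }.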

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of_val(x);

            // --- ping-pong to overlap latency ---------------------------------------
            // Clone the Arc references
            let right = self.right.clone();
            let left = self.left.clone();

            // View the local slice as bytes that can be written on the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Re-use (or allocate) a receive buffer of identical size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            // Lock both sockets once to avoid per-call mutex overhead.
            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            // For the typical tensor size we see (~6 KiB) a single
            // write/read pair is faster than chunking because the extra
            // system-call and loop overhead dominates. Only fall back to the
            // chunked "ping-pong" pipeline for larger transfers.
            if nbytes <= 8 * 1024 {
                // --- fast path: one shot ------------------------------------
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // --- slow path: chunked ping-pong ---------------------------
                const CHUNK_SIZE: usize = 64 * 1024; // 64 KiB
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    // send this chunk to the right neighbour
                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    // receive the matching chunk from the left neighbour
                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // -------------------------------------------------------------------------
            // Interpret the received bytes as a slice of T and add element-wise into x
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Validate gather dimension
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of_val(x);

            // Prepare output buffer that will hold slices from every rank.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            // Copy this rank's slice into its final slot.
            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                // ---------- send to the right ----------
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                // ---------- receive from the left ----------
                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // Determine which global rank the received slice came from.
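                // e.g. with world_size = 4 and rank = 1: step 0 receives the
                // slice originating from rank 0, step 1 the one from rank 3,
                // and step 2 the one from rank 2.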
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward that slice in the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

// Dummy operations
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}