mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub use ops::{AllGather, Comm, Id, SumAllReduce};
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

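/// Per-rank configuration for the TCP ring backend, deserialized from the JSON
/// file pointed to by the `RING_CONFIG` environment variable.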
#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring backend config from the JSON file pointed to by the
    /// `RING_CONFIG` environment variable.
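    ///
    /// A minimal illustrative config for rank 1 of a two-node ring; the
    /// addresses and ports below are placeholders, not defaults:
    ///
    /// ```json
    /// {
    ///   "master_ip": "192.168.1.10",
    ///   "master_port": 1234,
    ///   "port": 4001,
    ///   "right_port": 4000,
    ///   "right_ip": "192.168.1.10",
    ///   "rank": 1,
    ///   "world_size": 2
    /// }
    /// ```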
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

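/// Returns the global tensor-parallel world size implied by the enabled backend:
/// the CUDA device count with the `cuda` feature, the ring config's `world_size`
/// with the `ring` feature, and 1 otherwise.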
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    Ok(1)
}

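/// Returns true when the NCCL backend should be used: the `nccl` and `cuda`
/// features are enabled and `MISTRALRS_NO_NCCL` is not set to "1".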
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

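// NCCL backend: the collectives below delegate to cudarc's NCCL bindings and
// operate directly on CUDA device buffers.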
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod ops {
    use std::{fmt::Debug, ops::Deref, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
        Device, Layout, Result, Shape, Tensor,
    };

    #[derive(Debug, Clone, Copy)]
    pub struct Id(cudarc::nccl::Id);

    impl Id {
        pub fn new() -> Self {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            Self(id)
        }

        pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self(cudarc::nccl::Id::uninit(internal))
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            self.0.internal()
        }
    }

    #[derive(Debug)]
    pub struct Comm {
        comm: cudarc::nccl::Comm,
    }

    impl Comm {
        pub fn from_device(id: Id, dev: &Device, rank: usize, world_size: usize) -> Result<Self> {
            let device = dev.as_cuda_device()?.cuda_device();
            assert_eq!(rank, device.ordinal());
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, id.0)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for Comm {}
    unsafe impl Send for Comm {}

    impl Deref for Comm {
        type Target = cudarc::nccl::Comm;

        fn deref(&self) -> &Self::Target {
            &self.comm
        }
    }

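    // Illustrative wiring (a sketch, not exercised here): rank 0 creates an
    // `Id`, shares its `internal()` bytes with the other ranks (for example via
    // the `socket` module), and each rank then builds a `Comm` for its device:
    //
    //     let id = Id::new();
    //     let comm = Arc::new(Comm::from_device(id, &device, rank, world_size)?);
    //     let summed = SumAllReduce::new(&comm).sum_all_reduce(&tensor)?;
    //     let gathered = AllGather::new(&comm, 0).all_gather(&tensor)?;

    /// Element-wise sum all-reduce across all ranks of the `Comm`, backed by
    /// NCCL's all-reduce via cudarc.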
    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();
            let dst = match s.dtype() {
                DType::BF16 => {
                    let s = s.as_cuda_slice::<bf16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    assert_eq!(dev.ordinal(), self.comm.rank());
                    assert!(elem_count > 0);
                    let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F16 => {
                    let s = s.as_cuda_slice::<f16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F32 => {
                    let s = s.as_cuda_slice::<f32>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
            };
            Ok((dst, l.shape().clone()))
        }
    }

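    /// Gathers every rank's shard along `dim`, producing a tensor whose size on
    /// that dimension is `world_size` times the input's, backed by NCCL's
    /// all-gather via cudarc.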
    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::driver::DeviceSlice;
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] *= self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();
            let dst = match s.dtype() {
                DType::BF16 => {
                    let s = s.as_cuda_slice::<bf16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    assert_eq!(dev.ordinal(), self.comm.rank());
                    assert!(elem_count > 0);
                    let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_gather(s, &mut dst)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F16 => {
                    let s = s.as_cuda_slice::<f16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_gather(s, &mut dst)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F32 => {
                    let s = s.as_cuda_slice::<f32>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_gather(s, &mut dst)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
            };
            Ok((dst, out_shape))
        }
    }
}

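// TCP ring backend: every rank accepts a connection from its left neighbour and
// connects to its right neighbour, and the collectives pass CPU buffers around
// that ring.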
#[cfg(feature = "ring")]
mod ops {
    use std::{
        collections::HashMap,
        fmt::Debug,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    // Friendly aliases to tame type complexity.
    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    // Lazily initialized pair of TCP streams shared by every ring-based collective op.
    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

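    // Wiring: each rank listens on `config.port` for its left neighbour and
    // repeatedly dials `config.right_ip():config.right_port` (for up to ten
    // seconds) until its right neighbour is up, so a valid config forms a
    // closed ring.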
    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                // Connect to the right neighbour using the provided IP
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right neighbour within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                // Accept connection from the left neighbour
                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Debug, Clone, Copy)]
    pub struct Id;

    impl Default for Id {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Id {
        pub fn new() -> Self {
            Self
        }

        pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
            &ZEROED_ID
        }
    }

    #[derive(Debug)]
    pub struct Comm {
        config: RingConfig,
    }

    impl Comm {
        pub fn from_device(
            _id: Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            // Validate ring configuration
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<Comm>) -> Self {
            let (left, right) = get_ring_streams(&comm.config);
            Self {
                left,
                right,
                buffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Byte size of the local slice (`size_of_val` on a slice already
            // accounts for its length).
            let nbytes = std::mem::size_of_val(x);

            // --- ping-pong to overlap latency ---------------------------------------
            // Clone the Arc references
            let right = self.right.clone();
            let left = self.left.clone();

            // View the local slice as bytes that can be written on the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Re-use (or allocate) a receive buffer of identical size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            // Lock both sockets once to avoid per-call mutex overhead.
            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            // For the typical tensor size we see (~6 KiB) a single
            // write/read pair is faster than chunking because the extra
            // system-call and loop overhead dominates. Only fall back to the
            // chunked "ping-pong" pipeline for larger transfers.
            if nbytes <= 8 * 1024 {
                // --- fast path: one shot ------------------------------------
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // --- slow path: chunked ping-pong ---------------------------
                const CHUNK_SIZE: usize = 64 * 1024; // 64 KiB
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    // send this chunk to the right neighbour
                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    // receive the matching chunk from the left neighbour
                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // -------------------------------------------------------------------------
            // Interpret the received bytes as a slice of T; the caller adds this
            // tensor to the local one element-wise.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }
    }

    impl SumAllReduce {
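        /// Stages the tensor on the CPU, exchanges it with the ring neighbours,
        /// and returns the local tensor plus the received values (`xs + delta`).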
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<Comm>, dim: usize) -> Self {
            let (left, right) = get_ring_streams(&comm.config);
            Self {
                left,
                right,
                buffers: Arc::new(Mutex::new(HashMap::new())),
                dim,
                world_size: comm.world_size(),
                rank: comm.rank(),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Validate gather dimension
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            // Byte size of one rank's slice (`size_of_val` on a slice already
            // accounts for its length).
            let nbytes = std::mem::size_of_val(x);

            // Prepare output buffer that will hold slices from every rank.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            // Copy this rank’s slice into its final slot.
            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                // ---------- send to the right ----------
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                // ---------- receive from the left ----------
                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // Determine which global rank the received slice came from.
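                // For example, with world_size = 4 and rank = 2: step 0 receives
                // rank 1's slice, step 1 rank 0's, and step 2 rank 3's.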
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward that slice in the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

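// Fallback when neither NCCL nor the ring backend is enabled: a single-process
// "world" in which both collectives are identity operations.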
#[cfg(not(any(all(feature = "cuda", feature = "nccl"), feature = "ring")))]
mod ops {
    use std::sync::Arc;

    use candle_core::{Device, Result, Tensor};

    #[derive(Debug, Clone, Copy)]
    pub struct Id;

    impl Default for Id {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Id {
        pub fn new() -> Self {
            Self
        }

        pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
            &ZEROED_ID
        }
    }

    #[derive(Debug)]
    pub struct Comm;

    impl Comm {
        pub fn from_device(
            _id: Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<Comm>) -> Self {
            Self
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<Comm>, _dim: usize) -> Self {
            Self
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}