mistralrs_quant/distributed/mod.rs

//! Distributed primitives used for tensor parallelism. When the `cuda` and `nccl`
//! features are enabled, `ops` provides NCCL-backed collectives; otherwise a
//! single-process no-op fallback is compiled in.

use std::{fmt::Debug, sync::Barrier};

use candle_core::{Result, Tensor};
pub use ops::{Comm, Id, SumAllReduce};
pub mod layers;
pub mod socket;

/// A barrier that ranks/threads can wait on to synchronize with each other.
pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}
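
// A minimal sketch of how a `BarrierLike` might be held and used; `world_size` and
// the `Arc<dyn BarrierLike>` trait object are illustrative assumptions, not something
// this module prescribes.
//
// let barrier: std::sync::Arc<dyn BarrierLike> =
//     std::sync::Arc::new(std::sync::Barrier::new(world_size));
// barrier.wait()?;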

/// Returns the global tensor-parallel world size: the number of visible CUDA devices
/// when the `cuda` feature is enabled, and 1 otherwise.
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(not(feature = "cuda"))]
    Ok(1)
}

/// Returns true if NCCL should be used: the `nccl` and `cuda` features must be enabled
/// and the `MISTRALRS_NO_NCCL` environment variable must not be set to "1".
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}
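
// Sketch (illustrative only): a caller might gate communicator setup on this check.
// The branch contents below are placeholders, not code from this crate.
//
// if use_nccl() {
//     // build NCCL communicators via `Comm::from_device` for each rank
// } else {
//     // single-process path: `SumAllReduce` degenerates to a clone (see below)
// }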

/// A collective operation executed across all ranks of a communicator (`Comm`).
pub trait DistributedOperation {
    /// Sum-reduce `xs` across all ranks, returning the summed tensor on each rank.
    fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor>;
}
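
// End-to-end usage sketch (illustrative only; `rank`, `world_size`, `xs`, and
// `candle_core::Device::new_cuda` are assumptions about the caller, not defined
// here). With the NCCL backend, every rank must build its `Comm` from the same
// `Id` and call the reduction collectively; with the fallback backend the call is
// just a clone. Requires the `DistributedOperation` trait to be in scope.
//
// let id = Id::new(); // must be shared with all other ranks
// let dev = candle_core::Device::new_cuda(rank)?;
// let comm = std::sync::Arc::new(Comm::from_device(id, &dev, rank, world_size)?);
// let summed = SumAllReduce::new(&comm).sum_all_reduce(&xs)?;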

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod ops {
    use std::{fmt::Debug, ops::Deref, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
        Device, Layout, Result, Shape, Tensor,
    };

    /// Unique NCCL identifier; every rank of a communicator must be created from the same `Id`.
    #[derive(Debug, Clone, Copy)]
    pub struct Id(cudarc::nccl::Id);

    impl Id {
        pub fn new() -> Self {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            Self(id)
        }

        pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self(cudarc::nccl::Id::uninit(internal))
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            self.0.internal()
        }
    }

    /// Wrapper around an NCCL communicator bound to one CUDA device (one rank).
    #[derive(Debug)]
    pub struct Comm {
        comm: cudarc::nccl::Comm,
    }

    impl Comm {
        pub fn from_device(id: Id, dev: &Device, rank: usize, world_size: usize) -> Result<Self> {
            let device = dev.as_cuda_device()?.cuda_device();
            assert_eq!(rank, device.ordinal());
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, id.0)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for Comm {}
    unsafe impl Send for Comm {}

    impl Deref for Comm {
        type Target = cudarc::nccl::Comm;

        fn deref(&self) -> &Self::Target {
            &self.comm
        }
    }

    /// Sum all-reduce implemented as a candle custom op on top of NCCL.
    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl super::DistributedOperation for SumAllReduce {
        fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();
            // The input must be contiguous; for each supported dtype, allocate an output
            // buffer of the same length and let NCCL reduce into it.
            let dst = match s.dtype() {
                DType::BF16 => {
                    let s = s.as_cuda_slice::<bf16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    assert_eq!(dev.ordinal(), self.comm.rank());
                    assert!(elem_count > 0);
                    let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F16 => {
                    let s = s.as_cuda_slice::<f16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F32 => {
                    let s = s.as_cuda_slice::<f32>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
            };
            Ok((dst, l.shape().clone()))
        }
    }
}

#[cfg(not(all(feature = "cuda", feature = "nccl")))]
mod ops {
    //! Single-process fallbacks used when NCCL is unavailable: `Comm` stands in for a
    //! world of size 1 and `SumAllReduce` is a no-op.

    use std::sync::Arc;

    use candle_core::{Device, Result, Tensor};

    #[derive(Debug, Clone, Copy)]
    pub struct Id;

    impl Default for Id {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Id {
        pub fn new() -> Self {
            Self
        }

        pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
            &ZEROED_ID
        }
    }

    /// No-op communicator for the single-process case.
    #[derive(Debug)]
    pub struct Comm;

    impl Comm {
        pub fn from_device(
            _id: Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }

    /// With a single rank there is nothing to reduce, so this simply clones the input.
    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<Comm>) -> Self {
            Self
        }
    }

    impl super::DistributedOperation for SumAllReduce {
        fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}