mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, sync::Barrier};

use candle_core::{Result, Tensor};
pub use ops::{Comm, Id, SumAllReduce};
pub mod layers;
pub mod socket;

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}
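
// Example (sketch, not part of the upstream file): a `std::sync::Barrier` can
// be shared between worker threads behind the `BarrierLike` trait object; the
// thread count here is illustrative.
//
//     use std::sync::{Arc, Barrier};
//     let barrier: Arc<dyn BarrierLike> = Arc::new(Barrier::new(2));
//     let b = barrier.clone();
//     std::thread::spawn(move || b.wait().unwrap());
//     barrier.wait()?;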

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(not(feature = "cuda"))]
    Ok(1)
}
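
// Example (sketch): sizing a tensor-parallel group from the visible devices.
// With the `cuda` feature disabled this always reports a single device.
//
//     let world_size = get_global_tp_size_from_devices()?;
//     assert!(world_size >= 1);
//     // `world_size` is then passed to `Comm::from_device` on every rank.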

pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}
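
// Example (sketch): NCCL can be opted out of at runtime by setting
// `MISTRALRS_NO_NCCL=1` in the environment; any other value, or leaving the
// variable unset, keeps NCCL enabled provided the `nccl` and `cuda` features
// were compiled in.
//
//     std::env::set_var("MISTRALRS_NO_NCCL", "1");
//     assert!(!use_nccl());
//     std::env::remove_var("MISTRALRS_NO_NCCL");
//     assert_eq!(use_nccl(), cfg!(feature = "nccl") && cfg!(feature = "cuda"));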

pub trait DistributedOperation {
    fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor>;
}
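
// Example (sketch): the usual tensor-parallel pattern this trait exists for.
// Each rank multiplies by its shard of a weight matrix and the partial results
// are summed across ranks; `xs`, `weight_shard`, and `all_reduce` are
// hypothetical values.
//
//     let partial = xs.matmul(&weight_shard)?;
//     let output = all_reduce.sum_all_reduce(&partial)?;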

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod ops {
    use std::{fmt::Debug, ops::Deref, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
        Device, Layout, Result, Shape, Tensor,
    };

    #[derive(Debug, Clone, Copy)]
    pub struct Id(cudarc::nccl::Id);

    impl Id {
        pub fn new() -> Self {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            Self(id)
        }

        pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self(cudarc::nccl::Id::uninit(internal))
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            self.0.internal()
        }
    }
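
    // Example (sketch): shipping the NCCL unique id to peer processes. The raw
    // 128-byte payload from `internal()` can be sent over any transport (the
    // `socket` module in this directory is one option) and rebuilt on the
    // receiving side with `Id::uninit`.
    //
    //     let id = Id::new();
    //     let bytes: [::core::ffi::c_char; 128] = *id.internal();
    //     // ... transmit `bytes` to the other ranks ...
    //     let remote_id = Id::uninit(bytes);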

    #[derive(Debug)]
    pub struct Comm {
        comm: cudarc::nccl::Comm,
    }

    impl Comm {
        pub fn from_device(id: Id, dev: &Device, rank: usize, world_size: usize) -> Result<Self> {
            let device = dev.as_cuda_device()?.cuda_device();
            assert_eq!(rank, device.ordinal());
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, id.0)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }
    }

    unsafe impl Sync for Comm {}
    unsafe impl Send for Comm {}

    impl Deref for Comm {
        type Target = cudarc::nccl::Comm;

        fn deref(&self) -> &Self::Target {
            &self.comm
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl super::DistributedOperation for SumAllReduce {
        fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }
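
    // The reduction itself is wired in below as a candle `CustomOp1`: calling
    // `apply_op1_no_bwd` above routes to `cuda_fwd` for CUDA tensors (no
    // backward pass is registered), while `cpu_fwd` simply rejects CPU inputs.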

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();
            let dst = match s.dtype() {
                DType::BF16 => {
                    let s = s.as_cuda_slice::<bf16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    assert_eq!(dev.ordinal(), self.comm.rank());
                    assert!(elem_count > 0);
                    let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F16 => {
                    let s = s.as_cuda_slice::<f16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F32 => {
                    let s = s.as_cuda_slice::<f32>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
            };
            Ok((dst, l.shape().clone()))
        }
    }
}
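
// Example (sketch, CUDA + NCCL build): end-to-end setup of a summed all-reduce
// with one process per GPU. How the `Id` reaches the non-zero ranks is elided;
// `rank` is assumed to come from the launcher, and `DistributedOperation` must
// be in scope for `sum_all_reduce`.
//
//     use std::sync::Arc;
//     use candle_core::{DType, Device, Tensor};
//
//     let world_size = get_global_tp_size_from_devices()?;
//     let id = Id::new(); // created on rank 0 and shared with the other ranks
//     let device = Device::new_cuda(rank)?;
//     let comm = Arc::new(Comm::from_device(id, &device, rank, world_size)?);
//     let all_reduce = SumAllReduce::new(&comm);
//
//     let xs = Tensor::ones((2, 4), DType::F16, &device)?;
//     let summed = all_reduce.sum_all_reduce(&xs)?; // every element == world_size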

#[cfg(not(all(feature = "cuda", feature = "nccl")))]
mod ops {
    use std::sync::Arc;

    use candle_core::{Device, Result, Tensor};

    #[derive(Debug, Clone, Copy)]
    pub struct Id;

    impl Default for Id {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Id {
        pub fn new() -> Self {
            Self
        }

        pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
            &ZEROED_ID
        }
    }

    #[derive(Debug)]
    pub struct Comm;

    impl Comm {
        pub fn from_device(
            _id: Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<Comm>) -> Self {
            Self
        }
    }

    impl super::DistributedOperation for SumAllReduce {
        fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}
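
// Example (sketch, non-NCCL build): the same calling code keeps working; the
// stand-ins above give a rank-0, world-size-1 communicator and an all-reduce
// that returns its input unchanged.
//
//     use std::sync::Arc;
//     use candle_core::{DType, Device, Tensor};
//
//     let comm = Arc::new(Comm::from_device(Id::new(), &Device::Cpu, 0, 1)?);
//     let xs = Tensor::ones((2, 4), DType::F32, &Device::Cpu)?;
//     let ys = SumAllReduce::new(&comm).sum_all_reduce(&xs)?; // ys == xs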