use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub use ops::{AllGather, Comm, Id, SumAllReduce};
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring topology configuration from the JSON file named by the `RING_CONFIG`
    /// environment variable.
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}
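
// A minimal sketch, not used by the crate, of how `RING_CONFIG` might be supplied. The path,
// addresses, and ports below are illustrative assumptions only; the field names match
// `RingConfig` above (rank 0 may omit `master_ip`, every other rank must set it).
//
//   RING_CONFIG=/tmp/ring_rank0.json <binary>
//
//   // /tmp/ring_rank0.json
//   // {
//   //   "master_port": 29500,
//   //   "port": 30000,
//   //   "right_port": 30001,
//   //   "right_ip": "192.168.1.11",
//   //   "rank": 0,
//   //   "world_size": 2
//   // }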

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    Ok(1)
}

pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}
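
// Sketch of the gate above: NCCL is chosen only when both the `nccl` and `cuda` features are
// compiled in and the user has not opted out via the environment.
//
//   MISTRALRS_NO_NCCL=1 <binary>   # `use_nccl()` returns false even on CUDA+NCCL builds
//
// Any other value, or leaving the variable unset, keeps NCCL enabled.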

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod ops {
    use std::{fmt::Debug, ops::Deref, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
        Device, Layout, Result, Shape, Tensor,
    };

    #[derive(Debug, Clone, Copy)]
    pub struct Id(cudarc::nccl::Id);

    impl Id {
        pub fn new() -> Self {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            Self(id)
        }

        pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self(cudarc::nccl::Id::uninit(internal))
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            self.0.internal()
        }
    }

    #[derive(Debug)]
    pub struct Comm {
        comm: cudarc::nccl::Comm,
    }

    impl Comm {
        pub fn from_device(id: Id, dev: &Device, rank: usize, world_size: usize) -> Result<Self> {
            let device = dev.as_cuda_device()?.cuda_device();
            assert_eq!(rank, device.ordinal());
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, id.0)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }
    }
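
    // A hedged usage sketch (not part of this module's API surface): rank 0 creates the unique
    // NCCL id and shares it with the other ranks out of band, then each rank constructs its
    // communicator against its own CUDA device (the device ordinal must equal the rank).
    //
    //   let id = Id::new();                                        // rank 0 only
    //   let comm = Comm::from_device(id, &dev, rank, world_size)?;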

    unsafe impl Sync for Comm {}
    unsafe impl Send for Comm {}

    impl Deref for Comm {
        type Target = cudarc::nccl::Comm;

        fn deref(&self) -> &Self::Target {
            &self.comm
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }
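
    // Minimal sketch, assuming `comm: Arc<Comm>` and `xs` resident on this rank's CUDA device:
    // every rank calls with a tensor of the same shape and receives the element-wise sum.
    //
    //   let reduce = SumAllReduce::new(&comm);
    //   let summed = reduce.sum_all_reduce(&xs)?;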

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();
            let dst = match s.dtype() {
                DType::BF16 => {
                    let s = s.as_cuda_slice::<bf16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    assert_eq!(dev.ordinal(), self.comm.rank());
                    assert!(elem_count > 0);
                    let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F16 => {
                    let s = s.as_cuda_slice::<f16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F32 => {
                    let s = s.as_cuda_slice::<f32>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_reduce(s, &mut dst, &ReduceOp::Sum)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
            };
            Ok((dst, l.shape().clone()))
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }
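
    // Companion sketch for the gather path: with `dim = 0`, a per-rank `[n, h]` shard becomes
    // an `[n * world_size, h]` tensor on every rank (the shape math lives in `cuda_fwd` below).
    //
    //   let gather = AllGather::new(&comm, 0);
    //   let full = gather.all_gather(&shard)?;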

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::driver::DeviceSlice;
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();
            let dst = match s.dtype() {
                DType::BF16 => {
                    let s = s.as_cuda_slice::<bf16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    assert_eq!(dev.ordinal(), self.comm.rank());
                    assert!(elem_count > 0);
                    let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_gather(s, &mut dst)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F16 => {
                    let s = s.as_cuda_slice::<f16>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_gather(s, &mut dst)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                DType::F32 => {
                    let s = s.as_cuda_slice::<f32>()?;
                    let s = match l.contiguous_offsets() {
                        Some((0, l)) if l == s.len() => s,
                        Some(_) | None => candle_core::bail!("input has to be contiguous"),
                    };
                    let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                    self.comm
                        .comm
                        .all_gather(s, &mut dst)
                        .map_err(candle_core::Error::debug)?;
                    candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                }
                dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
            };
            Ok((dst, out_shape))
        }
    }
}

#[cfg(feature = "ring")]
mod ops {
    use std::{
        collections::HashMap,
        fmt::Debug,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                // Listen on this node's own port for the left neighbour.
                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                // Keep retrying the connection to the right neighbour until it comes up.
                let start = Instant::now();
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }
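
    // Hedged sketch of the wiring established above for a two-node ring (addresses are
    // assumptions): each rank binds its own `port`, dials `right_ip:right_port`, then accepts
    // its left neighbour, so the two ranks simply point at each other.
    //
    //   rank 0: port = 30000, right_ip = "192.168.1.11", right_port = 30001
    //   rank 1: port = 30001, right_ip = "192.168.1.10", right_port = 30000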

    #[derive(Debug, Clone, Copy)]
    pub struct Id;

    impl Default for Id {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Id {
        pub fn new() -> Self {
            Self
        }

        pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
            &ZEROED_ID
        }
    }

    #[derive(Debug)]
    pub struct Comm {
        config: RingConfig,
    }

    impl Comm {
        pub fn from_device(
            _id: Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<Comm>) -> Self {
            let (left, right) = get_ring_streams(&comm.config);
            Self {
                left,
                right,
                buffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = std::mem::size_of_val(x);
            let right = self.right.clone();
            let left = self.left.clone();

            let data_bytes =
                unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Reuse a per-size receive buffer rather than allocating on every call.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            if nbytes <= 8 * 1024 {
                // Small payloads: one write to the right neighbour, one blocking read from the left.
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // Larger payloads: interleave fixed-size writes and reads so neither side of
                // the ring blocks on a full TCP buffer.
                const CHUNK_SIZE: usize = 64 * 1024;
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }
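
    // Usage sketch mirroring the NCCL path: `xs` has the same shape on every rank and the
    // result is `xs` plus the data received from the left neighbour. For a two-node ring this
    // single exchange already yields the full sum on both ranks.
    //
    //   let reduce = SumAllReduce::new(&comm);
    //   let summed = reduce.sum_all_reduce(&xs)?;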

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<Comm>, dim: usize) -> Self {
            let (left, right) = get_ring_streams(&comm.config);
            Self {
                left,
                right,
                buffers: Arc::new(Mutex::new(HashMap::new())),
                dim,
                world_size: comm.world_size(),
                rank: comm.rank(),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = std::mem::size_of_val(x);

            // Place the local shard at this rank's offset; the remaining slots are filled as
            // pieces arrive from the left neighbour.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                // Forward the most recently received piece to the right neighbour.
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // The piece received at this step originated `step + 1` ranks to the left.
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
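
    // Usage sketch matching the NCCL `AllGather`: gathering along `dim = 0` turns each rank's
    // `[n, h]` shard into `[n * world_size, h]`, with piece `i` landing at rank `i`'s offset as
    // the loop in `run` forwards pieces around the ring.
    //
    //   let gather = AllGather::new(&comm, 0);
    //   let full = gather.all_gather(&shard)?;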
}

#[cfg(not(any(all(feature = "cuda", feature = "nccl"), feature = "ring")))]
mod ops {
    use std::sync::Arc;

    use candle_core::{Device, Result, Tensor};

    #[derive(Debug, Clone, Copy)]
    pub struct Id;

    impl Default for Id {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Id {
        pub fn new() -> Self {
            Self
        }

        pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
            Self
        }

        pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
            static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
            &ZEROED_ID
        }
    }

    #[derive(Debug)]
    pub struct Comm;

    impl Comm {
        pub fn from_device(
            _id: Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<Comm>) -> Self {
            Self
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<Comm>, _dim: usize) -> Self {
            Self
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}