use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

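/// Topology for the TCP ring backend, deserialized from the JSON file named by the
/// `RING_CONFIG` environment variable: this node's rank and port, the world size, and
/// where to reach the master rank and the right-hand neighbour.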
#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring configuration from the JSON file named by the `RING_CONFIG`
    /// environment variable, panicking if the variable is unset or the file is invalid.
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    /// IP address of the master rank; defaults to 0.0.0.0 when unspecified.
    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    /// IP address of the right-hand neighbour; defaults to 0.0.0.0 when unspecified.
    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

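/// Minimal barrier abstraction so synchronization does not depend on a concrete barrier
/// type; implemented here for `std::sync::Barrier`.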
pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

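/// Determines the global tensor-parallel world size: the CUDA device count for the CUDA
/// backends (overridable via `MISTRALRS_MN_LOCAL_WORLD_SIZE` when NCCL is enabled), the
/// ring config's `world_size` for the CPU ring backend, and 1 when no distributed
/// backend is enabled.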
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(all(feature = "cuda", feature = "ring"))]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(all(not(feature = "cuda"), feature = "ring"))]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(all(feature = "cuda", feature = "nccl"))]
    {
        if let Ok(x) = std::env::var("MISTRALRS_MN_LOCAL_WORLD_SIZE") {
            use std::str::FromStr;
            Ok(usize::from_str(&x).expect("Not a number for MISTRALRS_MN_LOCAL_WORLD_SIZE!"))
        } else {
            use candle_core::cuda::WrapErr;
            candle_core::cuda::cudarc::driver::result::device::get_count()
                .w()
                .map(|x| x as usize)
        }
    }

    #[cfg(all(not(feature = "ring"), not(feature = "nccl")))]
    Ok(1)
}

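/// True when the `nccl` and `cuda` features are enabled and `MISTRALRS_NO_NCCL` is not
/// set to "1".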
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

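/// Communicator handle shared by the collective ops. `from_device` picks NCCL when
/// `use_nccl()` is true, otherwise the ring backend when the `ring` feature is enabled,
/// and falls back to a dummy single-process communicator.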
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}

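/// Backend-agnostic wrapper around an NCCL unique id. The `Dummy` variant is used
/// whenever NCCL is unavailable or disabled and exposes an all-zero id.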
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

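/// Thin wrapper around the `cudarc` NCCL communicator.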
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            assert_eq!(rank, stream.context().ordinal());
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => panic!("Expected NCCL Id variant"),
            };
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

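/// Ring backend communicator: a lightweight handle over the validated `RingConfig`.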
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

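/// Single-process fallback communicator (rank 0, world size 1).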
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

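/// Backend-dispatching sum all-reduce. Exactly one of the backend fields is populated,
/// matching the `Comm` variant the instance was built from.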
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

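/// Backend-dispatching all-gather along a fixed dimension, mirroring `SumAllReduce`.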
#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}

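/// NCCL implementations of `SumAllReduce` and `AllGather` as Candle custom CUDA ops.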
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    // Dispatch on dtype: each arm checks that the input is contiguous,
                    // allocates an output buffer, and runs an NCCL sum all-reduce into it.
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

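/// TCP ring implementations of `SumAllReduce` and `AllGather`. Tensor data is staged on
/// the CPU and exchanged with the ring neighbours over the shared `TcpStream`s.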
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

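    /// Lazily establishes, once per process, the two TCP connections used by the ring:
    /// bind a listener on `config.port`, connect to the right neighbour at
    /// `right_ip:right_port` (retrying for up to 10 seconds), then accept the left
    /// neighbour's incoming connection. Both streams are set to blocking mode with
    /// Nagle's algorithm disabled.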
    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to the right node within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of_val(x);

            let right = self.right.clone();
            let left = self.left.clone();

            // View the local data as raw bytes for the socket writes.
            let data_bytes =
                unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Reuse a per-size receive buffer to avoid reallocating on every call.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            if nbytes <= 8 * 1024 {
                // Small payloads: send to the right neighbour and read the left
                // neighbour's payload in one shot.
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // Larger payloads: interleave fixed-size writes and reads so neither
                // side stalls on a full socket buffer.
                const CHUNK_SIZE: usize = 64 * 1024;
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // Reinterpret the received bytes as `T` and build a tensor from them.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            // Ring ops run on host memory: copy device storage to the CPU first.
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            // Add the received neighbour contribution to the local tensor.
            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of_val(x);

            // Output holds one slice per rank; place our own contribution first.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            // Classic ring all-gather: in each of the world_size - 1 steps, forward the
            // most recently received piece to the right and accept a new piece from the
            // left, placing it in the slot of the rank it originated from.
            for step in 0..(self.world_size - 1) {
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // The piece received at this step originated `step + 1` ranks to our left.
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward this piece on the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

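/// No-op single-process implementations: both collectives return the input tensor
/// unchanged.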
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}