use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

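/// Configuration for the TCP ring backend, read from the JSON file named by
/// the `RING_CONFIG` environment variable. A minimal sketch of such a file for
/// a hypothetical two-node ring (all addresses and ports are illustrative):
///
/// ```json
/// {
///   "master_ip": "192.168.1.10",
///   "master_port": 6000,
///   "port": 6001,
///   "right_port": 6002,
///   "right_ip": "192.168.1.11",
///   "rank": 1,
///   "world_size": 2
/// }
/// ```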
#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the config from the JSON file named by the `RING_CONFIG`
    /// environment variable, panicking if the file is missing or invalid.
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    /// The master's IP, defaulting to `0.0.0.0` when unspecified.
    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or_else(|| "0.0.0.0".to_string())
    }

    /// The right neighbour's IP, defaulting to `0.0.0.0` when unspecified.
    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or_else(|| "0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        return candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize);
    }

    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        return Ok(config.world_size);
    }

    #[allow(unreachable_code)]
    Ok(1)
}

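/// Whether the NCCL backend should be used: requires both the `nccl` and
/// `cuda` features, and can be disabled at runtime with `MISTRALRS_NO_NCCL=1`
/// (any other value, or leaving it unset, keeps NCCL enabled).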
pub fn use_nccl() -> bool {
    cfg!(feature = "nccl")
        && cfg!(feature = "cuda")
        && std::env::var("MISTRALRS_NO_NCCL").map_or(true, |x| x != "1")
}

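/// A backend-agnostic communicator: NCCL when built with `cuda` + `nccl` (and
/// not disabled at runtime), the TCP ring when built with `ring`, otherwise a
/// single-process dummy. A minimal sketch of its use (a feature-less build
/// yielding the dummy backend; device and sizes are illustrative):
///
/// ```ignore
/// let comm = Comm::from_device(Id::new(), &candle_core::Device::Cpu, 0, 1)?;
/// assert_eq!(comm.rank(), 0);
/// assert_eq!(comm.world_size(), 1);
/// ```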
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}

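/// An opaque communicator id: the 128-byte NCCL unique id when NCCL is in use,
/// otherwise a zeroed placeholder.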
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but an NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            assert_eq!(rank, stream.context().ordinal());
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => panic!("Expected NCCL Id variant"),
            };
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    // SAFETY: cudarc's Comm is not automatically Send/Sync; we assert it here
    // because access is serialized by the surrounding ops.
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

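/// Ring backend communicator: validates and carries the `RingConfig`; the
/// actual TCP collectives live in `ring_ops` below.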
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

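/// Feature-dispatched sum all-reduce: each rank ends up with the elementwise
/// sum of every rank's tensor. A minimal sketch (`comm` and `xs` are
/// illustrative):
///
/// ```ignore
/// let op = SumAllReduce::new(&comm); // comm: Arc<Comm>
/// let summed = op.sum_all_reduce(&xs)?; // same shape and dtype as xs
/// ```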
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

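/// Feature-dispatched all-gather: concatenates each rank's shard along `dim`,
/// so that dimension grows by a factor of `world_size`. Sketch (illustrative):
///
/// ```ignore
/// let gather = AllGather::new(&comm, 0);
/// let full = gather.all_gather(&shard)?; // dim 0 is world_size times larger
/// ```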
#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}

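/// NCCL-backed collectives, implemented as candle `CustomOp1`s so they run
/// directly on CUDA storage.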
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] *= self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert!(elem_count > 0);
                            assert_eq!(dev.cuda_stream().context().ordinal(), nccl_comm.rank());
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

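/// TCP-ring-backed collectives. Each rank owns two blocking streams, one to its
/// left neighbour (accepted) and one to its right neighbour (dialed); data
/// always flows rightward around the ring.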
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                // Listen for the left neighbour on our own port, then dial the
                // right neighbour, retrying until it is up.
                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node within 10 seconds");
                        }
                        Err(_) => std::thread::sleep(Duration::from_millis(10)),
                    }
                };

                let (left, _) = left_listener.accept().expect("accept left neighbour");

                // Low-latency blocking I/O on both links.
                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

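        // A single neighbour exchange: push our bytes to the right neighbour
        // while pulling the left neighbour's bytes, and return the received
        // tensor (the caller adds it to its own). For a two-node ring this
        // yields the complete sum, since each rank's left and right neighbour
        // coincide.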
        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Byte length of the payload (x.len() * size_of::<T>()).
            let nbytes = std::mem::size_of_val(x);

            let right = self.right.clone();
            let left = self.left.clone();

            // View the input as raw bytes for the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Reuse one receive buffer per message size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            if nbytes <= 8 * 1024 {
                // Small messages: a single write and read fit in the socket buffers.
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // Large messages: interleave chunked writes and reads so neither
                // side stalls on a full socket buffer.
                const CHUNK_SIZE: usize = 64 * 1024;
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // Reinterpret the received bytes as T; lengths match by construction.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            // Byte length of one shard (elem_cnt * size_of::<T>()).
            let nbytes = std::mem::size_of_val(x);

            // Output holds one shard per rank; place our own shard first.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            for step in 0..(self.world_size - 1) {
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

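                // The piece received at `step` originated `step + 1` hops to our
                // left, e.g. with world_size = 4 and rank = 2: step 0 receives
                // rank 1's shard, step 1 rank 0's, step 2 rank 3's.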
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward the newly received piece on the next step.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}