use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};
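// Illustrative sketch (not taken from the original sources): a RING_CONFIG JSON file
// matching the `RingConfig` fields below might look like this for rank 0 of a
// two-node ring. The path, addresses, and port numbers are placeholders.
//
//   RING_CONFIG=/path/to/ring_rank0.json
//   {
//     "master_ip": "192.168.1.10",
//     "master_port": 29500,
//     "port": 12001,
//     "right_port": 12002,
//     "right_ip": "192.168.1.11",
//     "rank": 0,
//     "world_size": 2
//   }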
#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}
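/// Returns the global tensor-parallel world size: the number of visible CUDA
/// devices when the `cuda` feature is enabled, the `world_size` from `RingConfig`
/// when the `ring` feature is enabled, and 1 otherwise.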
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda::WrapErr;
        candle_core::cuda::cudarc::driver::result::device::get_count()
            .w()
            .map(|x| x as usize)
    }
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        Ok(config.world_size)
    }

    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    Ok(1)
}
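/// NCCL is used when the `nccl` and `cuda` features are both enabled and the
/// `MISTRALRS_NO_NCCL` environment variable is not set to "1".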
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}
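/// A communicator backed by one of the compiled-in backends. `from_device`
/// prefers NCCL when available and enabled, then the TCP ring backend, and
/// finally falls back to a no-op dummy communicator for single-device use.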
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}
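/// Identifier used to bootstrap a communicator. Only the NCCL backend carries a
/// real `cudarc::nccl::Id`; the other backends use the `Dummy` variant.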
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let device = dev.as_cuda_device()?.cuda_device();
            assert_eq!(rank, device.ordinal());
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => panic!("Expected NCCL Id variant"),
            };
            Ok(Self {
                comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, nccl_id)
                    .map_err(|e| e.0)
                    .expect("Failed to create `Comm`, error code"),
            })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}
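/// Backend-dispatching sum all-reduce. Exactly one of the optional fields is
/// populated, depending on the `Comm` variant passed to `new`.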
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}
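/// Backend-dispatching all-gather along a fixed dimension, mirroring the
/// dispatch logic of `SumAllReduce` above.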
#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}
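// The NCCL ops below are implemented as `CustomOp1` kernels so they operate directly
// on a tensor's CUDA storage; only contiguous BF16/F16/F32 inputs are handled by the
// match arms, and the CPU path always bails.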
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, cuda_backend::WrapErr, CpuStorage, CustomOp1, DType,
        Layout, Result, Shape, Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::{driver::DeviceSlice, nccl::ReduceOp};
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert_eq!(dev.ordinal(), nccl_comm.rank());
                            assert!(elem_count > 0);
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::driver::DeviceSlice;
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            assert_eq!(dev.ordinal(), nccl_comm.rank());
                            assert!(elem_count > 0);
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}
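// The ring backend moves data over plain TCP: each rank binds a listener on its own
// `port` (for its left neighbour) and dials its right neighbour at `right_ip:right_port`,
// retrying for up to 10 seconds. The two resulting streams are cached in a `OnceLock`
// and reused for every collective.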
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node due to 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of::<T>();

            let right = self.right.clone();
            let left = self.left.clone();

            let data_bytes =
                unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;
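            // Exchange with the neighbours: small payloads (<= 8 KiB) are written and read
            // in one shot, larger ones are streamed in 64 KiB chunks, interleaving the write
            // to the right neighbour with the read from the left one.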
            if nbytes <= 8 * 1024 {
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                const CHUNK_SIZE: usize = 64 * 1024;
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of::<T>();

            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;
872 for step in 0..(self.world_size - 1) {
873 let bytes =
875 unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
876 {
877 let mut rg = right.lock().map_err(|e| {
878 candle_core::Error::msg(format!(
879 "Failed to lock right stream mutex: {:?}",
880 e
881 ))
882 })?;
883 rg.write_all(bytes)
884 .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
885 }
886
887 let mut bg = self.buffers.lock().map_err(|e| {
889 candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
890 })?;
891 let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
892 {
893 let mut lg = left.lock().map_err(|e| {
894 candle_core::Error::msg(format!(
895 "Failed to lock left stream mutex: {:?}",
896 e
897 ))
898 })?;
899 lg.read_exact(buf)
900 .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
901 }
902 let recv_piece: &[T] =
903 unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };
904
905 let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
907 let dst = src_rank * elem_cnt;
908 out[dst..dst + elem_cnt].copy_from_slice(recv_piece);
909
910 send_piece = recv_piece;
912 }
913
914 let mut out_dims = dims.to_vec();
915 out_dims[self.dim] *= self.world_size;
916 Tensor::from_slice(&out, out_dims, device)
917 }
918
919 pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
920 let storage = xs.storage_and_layout().0;
921 let cpu_storage = match &*storage {
922 Storage::Cpu(s) => s,
923 Storage::Cuda(s) => &s.to_cpu_storage()?,
924 Storage::Metal(s) => &s.to_cpu_storage()?,
925 };
926
927 match cpu_storage {
928 CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
929 CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
930 CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
931 _ => candle_core::bail!("Unsupported dtype for ring backend"),
932 }
933 }
934 }
935}
936
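// Identity fallbacks used when no distributed backend is active (single device):
// both ops simply return the input tensor unchanged.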
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}
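// Minimal usage sketch (added for illustration, not part of the original module): with
// neither the `nccl` nor the `ring` feature enabled, `Comm::from_device` yields the
// dummy backend, so both collectives act as identity ops on a CPU tensor.
#[cfg(all(test, not(feature = "nccl"), not(feature = "ring")))]
mod dummy_usage_sketch {
    use super::*;
    use candle_core::{Device, Tensor};
    use std::sync::Arc;

    #[test]
    fn dummy_collectives_are_identity() -> Result<()> {
        // With no distributed backend compiled in, this resolves to the dummy comm.
        let comm = Arc::new(Comm::from_device(Id::new(), &Device::Cpu, 0, 1)?);
        let xs = Tensor::new(&[1f32, 2.0, 3.0], &Device::Cpu)?;

        let reduced = SumAllReduce::new(&comm).sum_all_reduce(&xs)?;
        assert_eq!(reduced.to_vec1::<f32>()?, xs.to_vec1::<f32>()?);

        let gathered = AllGather::new(&comm, 0).all_gather(&xs)?;
        assert_eq!(gathered.to_vec1::<f32>()?, xs.to_vec1::<f32>()?);
        Ok(())
    }
}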