mistralrs_quant/distributed/
socket.rs
1use std::{
2 io::{Read, Write},
3 net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
4 slice,
5 sync::{Barrier, Mutex},
6 time::{Duration, Instant},
7};
8
9use super::{BarrierLike, Id};
10use candle_core::Result;
11
/// Head-node side of the socket rendezvous: it broadcasts an [`Id`] to the connected
/// worker nodes and implements [`BarrierLike`] across nodes.
#[derive(Debug)]
pub struct Server {
    /// One stream per connection accepted in `Server::new`.
    connections: Vec<std::net::TcpStream>,
    /// First local barrier in `wait`; its leader performs the network exchange.
    barrier_all: std::sync::Barrier,
    /// Second local barrier in `wait`; releases every local rank once the leader is done.
    barrier_crossnode: std::sync::Barrier,
}
20
21impl Server {
22 pub fn new<A: ToSocketAddrs>(addr: &A, n_nodes: usize, n_local_ranks: usize) -> Result<Self> {
24 let listener = TcpListener::bind(addr)?;
25 listener.set_nonblocking(false)?;
26 let start = Instant::now();
27 let mut connections = Vec::with_capacity(n_nodes);
28 while connections.len() < n_nodes {
29 if let Ok((stream, _)) = listener.accept() {
30 stream.set_read_timeout(Some(Duration::from_secs_f32(10.)))?;
31 stream.set_write_timeout(Some(Duration::from_secs_f32(10.)))?;
32
33 connections.push(stream);
34 }
35 if start.elapsed() > Duration::from_secs(10) {
36 candle_core::bail!("Worker did not connect to head node due to timeout: over 10s");
37 }
38 }
39 Ok(Self {
40 connections,
41 barrier_all: Barrier::new(n_local_ranks),
42 barrier_crossnode: Barrier::new(n_local_ranks),
43 })
44 }
45
46 pub fn broadcast_id(&self, id: &Id) -> Result<()> {
48 let body = id.internal();
49 #[allow(clippy::unnecessary_cast)]
51 let body_bytes = unsafe { slice::from_raw_parts(body.as_ptr() as *const u8, body.len()) };
52 for mut stream in &self.connections {
53 stream.write_all(body_bytes)?;
54 stream.flush()?;
55 }
56 Ok(())
57 }
58}
59
60impl BarrierLike for Server {
61 fn wait(&self) -> Result<()> {
62 let res = self.barrier_all.wait();
64
65 if res.is_leader() {
66 for mut stream in &self.connections {
68 stream.write_all(b"g")?;
69 stream.flush()?;
70 }
71 let mut ack_buf = [0u8; 1];
73 for mut stream in &self.connections {
74 stream.read_exact(&mut ack_buf)?;
75 if &ack_buf != b"a" {
76 candle_core::bail!("Did not get Ack from worker node");
77 }
78 }
79 }
80
81 self.barrier_crossnode.wait();
82 Ok(())
83 }
84}
85
/// Worker-node side of the socket rendezvous: it receives the [`Id`] from the head node
/// and implements [`BarrierLike`] across nodes.
#[derive(Debug)]
pub struct Client {
    /// The single connection to the head node, locked for each protocol exchange.
    stream: std::sync::Mutex<std::net::TcpStream>,
    /// First local barrier in `wait`; its leader performs the network exchange.
    barrier_all: std::sync::Barrier,
    /// Second local barrier in `wait`; releases every local rank once the leader is done.
    barrier_crossnode: std::sync::Barrier,
}
94
95impl Client {
96 pub fn new(addr: SocketAddr, n_local_ranks: usize) -> Result<Self> {
97 let start = Instant::now();
98 loop {
99 let stream = TcpStream::connect(addr);
100 if let Ok(stream) = stream {
101 stream.set_nodelay(true)?;
102 stream.set_nonblocking(false)?;
103
104 stream.set_read_timeout(Some(Duration::from_secs_f32(10.)))?;
105 stream.set_write_timeout(Some(Duration::from_secs_f32(10.)))?;
106
107 return Ok(Self {
108 stream: Mutex::new(stream),
109 barrier_all: Barrier::new(n_local_ranks),
110 barrier_crossnode: Barrier::new(n_local_ranks),
111 });
112 }
113 if start.elapsed() > Duration::from_secs(10) {
114 candle_core::bail!("Failed to connect to head node due to timeout: over 10s");
115 }
116 }
117 }
118
119 pub fn receive_id(&self) -> Result<Id> {
121 let mut stream = self.stream.lock().unwrap();
122 let mut buffer = [0u8; 128];
123 stream.read_exact(&mut buffer)?;
124
125 let mut id_bytes: [core::ffi::c_char; 128] = [0; 128];
126 for (i, &b) in buffer.iter().enumerate() {
127 id_bytes[i] = b as core::ffi::c_char;
128 }
129 Ok(Id::uninit(id_bytes))
130 }
131}
132
133impl BarrierLike for Client {
134 fn wait(&self) -> Result<()> {
135 let res = self.barrier_all.wait();
137
138 if res.is_leader() {
139 let mut stream = self.stream.lock().unwrap();
140 let mut buf = [0u8; 1];
142 stream.read_exact(&mut buf)?;
143 if &buf != b"g" {
144 candle_core::bail!("Did not receive correct barrier signal from head node");
145 }
146 stream.write_all(b"a")?;
148 stream.flush()?;
149 }
150 self.barrier_crossnode.wait();
152 Ok(())
153 }
154}