mistralrs_quant/distributed/
socket.rs

1use std::{
2    io::{Read, Write},
3    net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
4    slice,
5    sync::{Barrier, Mutex},
6    time::{Duration, Instant},
7};
8
9use super::{BarrierLike, Id};
10use candle_core::Result;
11
/// Head-node side of the socket-based barrier.
///
/// Holds one long-lived TCP connection per worker node, plus the two thread
/// barriers that bracket the network exchange performed in `wait`.
#[derive(Debug)]
pub struct Server {
    /// One persistent TCP stream per connected node, in the order they were accepted.
    connections: Vec<TcpStream>,
    /// Local barrier hit first in `wait`; its leader performs the socket exchange.
    barrier_all: Barrier,
    /// Local barrier hit last in `wait`, after the leader finished the socket exchange.
    barrier_crossnode: Barrier,
}
20
21impl Server {
22    /// Binds the listener and then accepts exactly `n_nodes` persistent connections.
23    pub fn new<A: ToSocketAddrs>(addr: &A, n_nodes: usize, n_local_ranks: usize) -> Result<Self> {
24        let listener = TcpListener::bind(addr)?;
25        listener.set_nonblocking(false)?;
26        let start = Instant::now();
27        let mut connections = Vec::with_capacity(n_nodes);
28        while connections.len() < n_nodes {
29            if let Ok((stream, _)) = listener.accept() {
30                stream.set_read_timeout(Some(Duration::from_secs_f32(10.)))?;
31                stream.set_write_timeout(Some(Duration::from_secs_f32(10.)))?;
32
33                connections.push(stream);
34            }
35            if start.elapsed() > Duration::from_secs(10) {
36                candle_core::bail!("Worker did not connect to head node due to timeout: over 10s");
37            }
38        }
39        Ok(Self {
40            connections,
41            barrier_all: Barrier::new(n_local_ranks),
42            barrier_crossnode: Barrier::new(n_local_ranks),
43        })
44    }
45
46    /// Broadcasts the given ID over all persistent connections.
47    pub fn broadcast_id(&self, id: &Id) -> Result<()> {
48        let body = id.internal();
49        // SAFETY: We know the provenance and lifetime of `body` are valid.
50        #[allow(clippy::unnecessary_cast)]
51        let body_bytes = unsafe { slice::from_raw_parts(body.as_ptr() as *const u8, body.len()) };
52        for mut stream in &self.connections {
53            stream.write_all(body_bytes)?;
54            stream.flush()?;
55        }
56        Ok(())
57    }
58}
59
60impl BarrierLike for Server {
61    fn wait(&self) -> Result<()> {
62        // First, synchronize locally.
63        let res = self.barrier_all.wait();
64
65        if res.is_leader() {
66            // Leader sends the barrier signal "g" to every node.
67            for mut stream in &self.connections {
68                stream.write_all(b"g")?;
69                stream.flush()?;
70            }
71            // Now, wait to receive an acknowledgement "a" from every node.
72            let mut ack_buf = [0u8; 1];
73            for mut stream in &self.connections {
74                stream.read_exact(&mut ack_buf)?;
75                if &ack_buf != b"a" {
76                    candle_core::bail!("Did not get Ack from worker node");
77                }
78            }
79        }
80
81        self.barrier_crossnode.wait();
82        Ok(())
83    }
84}
85
/// Worker-node side of the socket-based barrier.
///
/// The single persistent connection to the head node lives behind a `Mutex`, so
/// methods taking `&self` can still obtain exclusive access to the stream.
#[derive(Debug)]
pub struct Client {
    /// Persistent TCP stream to the head node.
    stream: Mutex<TcpStream>,
    /// Local barrier hit first in `wait`; its leader performs the socket exchange.
    barrier_all: Barrier,
    /// Local barrier hit last in `wait`, after the leader finished the socket exchange.
    barrier_crossnode: Barrier,
}
94
95impl Client {
96    pub fn new(addr: SocketAddr, n_local_ranks: usize) -> Result<Self> {
97        let start = Instant::now();
98        loop {
99            let stream = TcpStream::connect(addr);
100            if let Ok(stream) = stream {
101                stream.set_nodelay(true)?;
102                stream.set_nonblocking(false)?;
103
104                stream.set_read_timeout(Some(Duration::from_secs_f32(10.)))?;
105                stream.set_write_timeout(Some(Duration::from_secs_f32(10.)))?;
106
107                return Ok(Self {
108                    stream: Mutex::new(stream),
109                    barrier_all: Barrier::new(n_local_ranks),
110                    barrier_crossnode: Barrier::new(n_local_ranks),
111                });
112            }
113            if start.elapsed() > Duration::from_secs(10) {
114                candle_core::bail!("Failed to connect to head node due to timeout: over 10s");
115            }
116        }
117    }
118
119    /// Receives the broadcasted ID from the persistent stream.
120    pub fn receive_id(&self) -> Result<Id> {
121        let mut stream = self.stream.lock().unwrap();
122        let mut buffer = [0u8; 128];
123        stream.read_exact(&mut buffer)?;
124
125        let mut id_bytes: [core::ffi::c_char; 128] = [0; 128];
126        for (i, &b) in buffer.iter().enumerate() {
127            id_bytes[i] = b as core::ffi::c_char;
128        }
129        Ok(Id::uninit(id_bytes))
130    }
131}
132
133impl BarrierLike for Client {
134    fn wait(&self) -> Result<()> {
135        // Synchronize locally.
136        let res = self.barrier_all.wait();
137
138        if res.is_leader() {
139            let mut stream = self.stream.lock().unwrap();
140            // Read the barrier signal "Go!" from the persistent stream.
141            let mut buf = [0u8; 1];
142            stream.read_exact(&mut buf)?;
143            if &buf != b"g" {
144                candle_core::bail!("Did not receive correct barrier signal from head node");
145            }
146            // Immediately send back an acknowledgement "Ack".
147            stream.write_all(b"a")?;
148            stream.flush()?;
149        }
150        // Synchronize again across local ranks.
151        self.barrier_crossnode.wait();
152        Ok(())
153    }
154}