mistralrs_quant/distributed/
socket.rsuse std::{
io::{Read, Write},
net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
slice,
sync::{Barrier, Mutex},
time::{Duration, Instant},
};
use super::{BarrierLike, Id};
use candle_core::Result;
/// Head-node side of the TCP rendezvous: broadcasts an [`Id`] to every worker node and
/// implements [`BarrierLike`] across nodes.
#[derive(Debug)]
pub struct Server {
    /// One stream per worker node accepted while constructing the server.
    connections: Vec<TcpStream>,
    /// Local rendezvous taken before the cross-node handshake; sized to `n_local_ranks`.
    barrier_all: Barrier,
    /// Local rendezvous taken after the cross-node handshake; sized to `n_local_ranks`.
    barrier_crossnode: Barrier,
}
impl Server {
    /// Binds a listener on `addr` and waits for `n_nodes` worker nodes to connect.
    ///
    /// `n_local_ranks` is the number of local callers of [`BarrierLike::wait`]; both internal
    /// barriers are sized to it.
    ///
    /// Each accepted stream is put in blocking mode with `TCP_NODELAY` and 10s read/write
    /// timeouts.
    ///
    /// # Errors
    /// Fails if binding or configuring a socket fails, or if fewer than `n_nodes` workers
    /// have connected within 10 seconds.
    pub fn new<A: ToSocketAddrs>(addr: &A, n_nodes: usize, n_local_ranks: usize) -> Result<Self> {
        const ACCEPT_TIMEOUT: Duration = Duration::from_secs(10);
        const POLL_INTERVAL: Duration = Duration::from_millis(10);
        const IO_TIMEOUT: Duration = Duration::from_secs(10);

        let listener = TcpListener::bind(addr)?;
        // Poll `accept` rather than blocking inside it. A blocking accept would never return
        // to the deadline check while no worker is connecting.
        listener.set_nonblocking(true)?;
        let start = Instant::now();
        let mut connections = Vec::with_capacity(n_nodes);
        while connections.len() < n_nodes {
            // `Err` here is either `WouldBlock` (nobody waiting yet) or a transient accept
            // failure; both are retried until the deadline below.
            if let Ok((stream, _)) = listener.accept() {
                // Some platforms let accepted sockets inherit the listener's non-blocking mode.
                // The barrier protocol relies on blocking I/O bounded by the timeouts below.
                stream.set_nonblocking(false)?;
                // The barrier exchanges single bytes; avoid Nagle delays (matches the client).
                stream.set_nodelay(true)?;
                stream.set_read_timeout(Some(IO_TIMEOUT))?;
                stream.set_write_timeout(Some(IO_TIMEOUT))?;
                connections.push(stream);
                // Re-check the loop condition first, so a successful final accept never bails.
                continue;
            }
            if start.elapsed() > ACCEPT_TIMEOUT {
                candle_core::bail!("Worker did not connect to head node due to timeout: over 10s");
            }
            std::thread::sleep(POLL_INTERVAL);
        }
        Ok(Self {
            connections,
            barrier_all: Barrier::new(n_local_ranks),
            barrier_crossnode: Barrier::new(n_local_ranks),
        })
    }

    /// Sends the raw bytes of `id` to every connected worker node.
    ///
    /// # Errors
    /// Propagates any write or flush failure, including a write timeout.
    pub fn broadcast_id(&self, id: &Id) -> Result<()> {
        let body = id.internal();
        // SAFETY: `body.as_ptr()` is valid for `body.len()` elements, and the bytes are only
        // read while `body` is borrowed. This assumes the elements of `internal()` are
        // single-byte `c_char`s, so `body.len()` elements equal `body.len()` bytes. The
        // client rebuilds the id from a `[c_char; 128]` — TODO confirm the element width of
        // `internal()`.
        let body_bytes = unsafe { slice::from_raw_parts(body.as_ptr() as *const u8, body.len()) };
        for mut stream in &self.connections {
            stream.write_all(body_bytes)?;
            stream.flush()?;
        }
        Ok(())
    }
}
impl BarrierLike for Server {
    /// Cross-node barrier, head-node side.
    ///
    /// 1. All local ranks meet on `barrier_all`.
    /// 2. The elected leader sends `b"g"` to every connected worker, then expects one `b"a"`
    ///    acknowledgement from each.
    /// 3. All local ranks meet again on `barrier_crossnode`.
    ///
    /// Step 3 ensures no local rank returns before the handshake has completed.
    ///
    /// # Errors
    /// Only the leader performs I/O, so only the leader can return an error: a write/read
    /// failure, a socket timeout, or an ack byte other than `b"a"`. The other local ranks
    /// return `Ok(())`.
    fn wait(&self) -> Result<()> {
        let handshake = || -> Result<()> {
            for mut stream in &self.connections {
                stream.write_all(b"g")?;
                stream.flush()?;
            }
            let mut ack_buf = [0u8; 1];
            for mut stream in &self.connections {
                stream.read_exact(&mut ack_buf)?;
                if &ack_buf != b"a" {
                    candle_core::bail!("Did not get Ack from worker node");
                }
            }
            Ok(())
        };

        let res = if self.barrier_all.wait().is_leader() {
            handshake()
        } else {
            Ok(())
        };
        // Reach the second barrier even if the handshake failed. Returning early would leave
        // every other local rank blocked on `barrier_crossnode` forever.
        self.barrier_crossnode.wait();
        res
    }
}
/// Worker-node side of the TCP rendezvous: receives the [`Id`] from the head node and
/// implements [`BarrierLike`] across nodes.
#[derive(Debug)]
pub struct Client {
    /// Connection to the head node. It sits behind a `Mutex` because the `&self` methods
    /// need mutable access to read from and write to it.
    stream: Mutex<TcpStream>,
    /// Local rendezvous taken before the cross-node handshake; sized to `n_local_ranks`.
    barrier_all: Barrier,
    /// Local rendezvous taken after the cross-node handshake; sized to `n_local_ranks`.
    barrier_crossnode: Barrier,
}
impl Client {
    /// Connects to the head node at `addr`, retrying until a 10 second budget is used up.
    ///
    /// `n_local_ranks` is the number of local callers of [`BarrierLike::wait`]; both internal
    /// barriers are sized to it.
    ///
    /// The connected stream uses `TCP_NODELAY`, blocking mode, and 10s read/write timeouts.
    ///
    /// # Errors
    /// Fails if no connection is established within 10 seconds, or if configuring the
    /// connected socket fails.
    pub fn new(addr: SocketAddr, n_local_ranks: usize) -> Result<Self> {
        const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
        // Back off between attempts, so an absent head node does not turn this into a busy spin.
        const RETRY_INTERVAL: Duration = Duration::from_millis(50);
        const IO_TIMEOUT: Duration = Duration::from_secs(10);

        let start = Instant::now();
        loop {
            let remaining = CONNECT_TIMEOUT.saturating_sub(start.elapsed());
            if remaining.is_zero() {
                candle_core::bail!("Failed to connect to head node due to timeout: over 10s");
            }
            // Bound every attempt by the remaining budget; a plain `connect` to an unresponsive
            // host can block far longer than 10s. `remaining` is non-zero here, as required.
            if let Ok(stream) = TcpStream::connect_timeout(&addr, remaining) {
                stream.set_nodelay(true)?;
                stream.set_nonblocking(false)?;
                stream.set_read_timeout(Some(IO_TIMEOUT))?;
                stream.set_write_timeout(Some(IO_TIMEOUT))?;
                return Ok(Self {
                    stream: Mutex::new(stream),
                    barrier_all: Barrier::new(n_local_ranks),
                    barrier_crossnode: Barrier::new(n_local_ranks),
                });
            }
            std::thread::sleep(RETRY_INTERVAL.min(remaining));
        }
    }

    /// Reads exactly 128 bytes from the head node and wraps them as an [`Id`].
    ///
    /// # Errors
    /// Propagates read failures, including a read timeout or the connection closing early.
    pub fn receive_id(&self) -> Result<Id> {
        let mut stream = self.stream.lock().unwrap();
        let mut buffer = [0u8; 128];
        stream.read_exact(&mut buffer)?;
        // Same-width reinterpretation of each byte as `c_char` (`i8` or `u8` by platform).
        Ok(Id::uninit(buffer.map(|b| b as core::ffi::c_char)))
    }
}
impl BarrierLike for Client {
    /// Cross-node barrier, worker-node side.
    ///
    /// 1. All local ranks meet on `barrier_all`.
    /// 2. The elected leader reads the `b"g"` signal from the head node and answers with
    ///    `b"a"`.
    /// 3. All local ranks meet again on `barrier_crossnode`.
    ///
    /// Step 3 ensures no local rank returns before the handshake has completed.
    ///
    /// # Errors
    /// Only the leader performs I/O, so only the leader can return an error: a read/write
    /// failure, a socket timeout, an unexpected signal byte, or a poisoned stream mutex. The
    /// other local ranks return `Ok(())`.
    fn wait(&self) -> Result<()> {
        let handshake = || -> Result<()> {
            let Ok(mut stream) = self.stream.lock() else {
                candle_core::bail!("Head node connection mutex was poisoned");
            };
            let mut buf = [0u8; 1];
            stream.read_exact(&mut buf)?;
            if &buf != b"g" {
                candle_core::bail!("Did not receive correct barrier signal from head node");
            }
            stream.write_all(b"a")?;
            stream.flush()?;
            Ok(())
        };

        let res = if self.barrier_all.wait().is_leader() {
            handshake()
        } else {
            Ok(())
        };
        // Reach the second barrier even if the handshake failed. Returning early would leave
        // every other local rank blocked on `barrier_crossnode` forever.
        self.barrier_crossnode.wait();
        res
    }
}