mistralrs_core/
distributed.rs

1use anyhow::Context;
2use candle_core::{DType, Device};
3use core::ffi::c_char;
4use interprocess::local_socket::traits::{Listener, Stream};
5use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
6use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
7pub use mistralrs_quant::distributed::use_nccl;
8use mistralrs_quant::{RingConfig, ShardedVarBuilder};
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11use std::env;
12use std::io::{BufRead, BufReader, Write};
13use std::net::TcpStream;
14use std::process::Command;
15use std::str::FromStr;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tokio::sync::mpsc::Sender;
19use tracing::info;
20
21use crate::device_map::DeviceMapper;
22use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
23use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
24use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};
25
26pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";
27
28pub fn is_daemon() -> bool {
29    if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
30        std::env::var(IS_DAEMON_FLAG).is_ok()
31    } else if cfg!(feature = "ring") {
32        !RingConfig::load().is_master_rank()
33    } else {
34        false
35    }
36}
37
38pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
39    use std::io::BufRead;
40    use std::io::BufReader;
41
42    std::thread::spawn(move || {
43        let rt = Runtime::new().unwrap();
44        rt.block_on(async move {
45            use interprocess::local_socket::traits::Stream;
46            use interprocess::local_socket::Stream as LocalStream;
47
48            loop {
49                let name = ipc_name().unwrap();
50                if let Ok(stream) = LocalStream::connect(name) {
51                    let mut reader = BufReader::new(stream);
52                    let mut buf = String::new();
53                    reader.read_line(&mut buf).unwrap();
54                    let mut req: Request = serde_json::from_str(&buf).unwrap();
55
56                    req = match req {
57                        Request::ReIsq(x) => Request::ReIsq(x),
58                        Request::Terminate => Request::Terminate,
59                        Request::Detokenize(mut x) => {
60                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
61                            x.response = sender;
62                            let req = Request::Detokenize(x);
63
64                            request_sender.send(req).await.unwrap();
65                            let resp = receiver.recv().await.unwrap();
66                            resp.unwrap();
67                            continue;
68                        }
69                        Request::Tokenize(mut x) => {
70                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
71                            x.response = sender;
72                            let req = Request::Tokenize(x);
73
74                            request_sender.send(req).await.unwrap();
75                            let resp = receiver.recv().await.unwrap();
76                            resp.unwrap();
77                            continue;
78                        }
79                        Request::Normal(mut x) => {
80                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
81                            x.is_streaming = false;
82                            x.response = sender;
83                            let req = Request::Normal(x);
84
85                            request_sender.send(req).await.unwrap();
86                            let resp = receiver.recv().await.unwrap();
87                            resp.as_result().unwrap();
88                            continue;
89                        }
90                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
91                    };
92
93                    request_sender.send(req).await.unwrap();
94                }
95            }
96        });
97    });
98}
99
100pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
101    use std::io::BufRead;
102    use std::io::BufReader;
103
104    let ring_config = RingConfig::load();
105
106    let master_ip = ring_config.master_ip();
107    let master_port = ring_config.master_port;
108    std::thread::spawn(move || {
109        let rt = Runtime::new().unwrap();
110        rt.block_on(async move {
111            loop {
112                if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
113                    let mut reader = BufReader::new(stream);
114                    let mut buf = String::new();
115                    reader.read_line(&mut buf).unwrap();
116                    let mut req: Request = serde_json::from_str(&buf).unwrap();
117
118                    req = match req {
119                        Request::ReIsq(x) => Request::ReIsq(x),
120                        Request::Terminate => Request::Terminate,
121                        Request::Detokenize(mut x) => {
122                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
123                            x.response = sender;
124                            let req = Request::Detokenize(x);
125
126                            request_sender.send(req).await.unwrap();
127                            let resp = receiver.recv().await.unwrap();
128                            resp.unwrap();
129                            continue;
130                        }
131                        Request::Tokenize(mut x) => {
132                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
133                            x.response = sender;
134                            let req = Request::Tokenize(x);
135
136                            request_sender.send(req).await.unwrap();
137                            let resp = receiver.recv().await.unwrap();
138                            resp.unwrap();
139                            continue;
140                        }
141                        Request::Normal(mut x) => {
142                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
143                            x.is_streaming = false;
144                            x.response = sender;
145                            let req = Request::Normal(x);
146
147                            request_sender.send(req).await.unwrap();
148                            let resp = receiver.recv().await.unwrap();
149                            resp.as_result().unwrap();
150                            continue;
151                        }
152                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
153                    };
154
155                    request_sender.send(req).await.unwrap();
156                }
157            }
158        });
159    });
160}
161
/// Serde-serializable wrapper around the 128-byte internal buffer of a
/// `mistralrs_quant::Id` (see its use in `prepare_distributed_mapper`).
/// `serde_big_array::BigArray` supplies the (de)serialization impl, since
/// serde does not derive for arrays this large.
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);
165
/// Payload passed (JSON-encoded) to a spawned worker process through the
/// `IS_DAEMON_FLAG` environment variable.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        // Raw bytes of the shared `mistralrs_quant::Id` the worker must adopt.
        id: BigCCharArray,
        // Zero-based worker index; the worker's local rank becomes
        // `worker_rank + 1` (the master holds rank 0).
        worker_rank: usize,
    },
}
173
174pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
175    let printname = "mistralrs_daemon.sock";
176    Ok(printname.to_ns_name::<GenericNamespaced>()?)
177}
178
179#[allow(clippy::too_many_arguments)]
180pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
181    dtype: DType,
182    device: &Device,
183    available_devices: &[Device],
184    silent: bool,
185    config: &str,
186    loading_isq: bool,
187    from_uqff: bool,
188    organization: IsqOrganization,
189    model: &T,
190    paths: &dyn ModelPaths,
191) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
192    if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
193        tracing::warn!(
194            "Distributed support was not included in the build, be sure to build with `--features nccl`."
195        );
196    }
197
198    // NCCL case!
199
200    let local_world_size = available_devices.len();
201    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
202        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
203    } else {
204        mistralrs_quant::distributed::get_global_tp_size_from_devices()?
205    };
206
207    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
208    if use_multi_node {
209        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
210    }
211
212    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
213        anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
214    }
215
216    info!("Local tensor parallel world size is {local_world_size}");
217    info!("Global tensor parallel world size is {global_world_size}");
218
219    // TP uses parallel pipelines.
220    let name = ipc_name()?;
221    let mut id;
222    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
223        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
224        let WorkerTransferData::Init {
225            id: new_id,
226            worker_rank,
227        } = payload;
228        id = mistralrs_quant::Id::uninit(new_id.0);
229
230        let mut stream = LocalStream::connect(name)?;
231        stream.write_all(b"ready\n")?;
232        worker_rank + 1
233    } else if cfg!(feature = "ring") {
234        id = mistralrs_quant::Id::new();
235
236        let config = RingConfig::load();
237
238        config.rank
239    } else {
240        id = mistralrs_quant::Id::new();
241        let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
242        let mut children = Vec::new();
243        for worker_rank in 0..num_workers {
244            let exe_path = env::current_exe().expect("Failed to get current exe");
245
246            let args: Vec<String> = env::args().collect();
247
248            let mut cmd = Command::new(exe_path);
249            cmd.args(&args[1..]);
250
251            let data = WorkerTransferData::Init {
252                id: BigCCharArray(*id.internal()),
253                worker_rank,
254            };
255
256            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);
257
258            cmd.stdout(std::process::Stdio::null());
259            cmd.stderr(std::process::Stdio::null());
260            cmd.stdin(std::process::Stdio::null());
261
262            children.push(cmd.spawn().expect("Failed to spawn process"));
263        }
264
265        let listener = ListenerOptions::new().name(name).create_sync()?;
266        let mut ready_count = 0;
267
268        while ready_count < num_workers {
269            let stream = listener.accept()?;
270            let mut reader = BufReader::new(stream);
271            let mut message = String::new();
272            reader.read_line(&mut message)?;
273            if message.trim() == "ready" {
274                ready_count += 1;
275            }
276        }
277        info!("All workers have received the ids!");
278
279        0
280    };
281
282    if use_multi_node {
283        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
284            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
285            info!("Head node managing {n_nodes} workers.");
286            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
287                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
288            };
289            info!("Head node initializing connection on {port}.");
290            let server = mistralrs_quant::Server::new(
291                &format!("0.0.0.0:{port}"),
292                n_nodes,
293                local_world_size,
294            )?;
295
296            server.broadcast_id(&id)?;
297        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
298            info!("Worker node connecting to {addr}.");
299            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;
300
301            id = client.receive_id()?;
302        }
303    }
304
305    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
306        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
307            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
308        };
309        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
310        info!("Worker ID is {node_id}.");
311        (node_id + 1) * local_world_size
312    } else {
313        0
314    };
315
316    // They each block on each other
317    // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
318    let comm = mistralrs_quant::Comm::from_device(
319        id,
320        device,
321        local_rank + rank_offset,
322        global_world_size,
323    )?;
324
325    let make_dummy_regexes = if loading_isq && from_uqff {
326        // Dummy weights for the layers which will be overwritten...
327        Some(std::sync::Arc::new(
328            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
329                model.isq_layer_regexes_moqe(config)?
330            } else {
331                model.isq_layer_regexes(config)?
332            },
333        ))
334    } else {
335        None
336    };
337
338    let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
339        paths.get_weight_filenames().to_vec(),
340        vec![],
341        Some(dtype),
342        &Device::Cpu,
343        vec![],
344        silent,
345        make_dummy_regexes,
346        |_| true,
347        Arc::new(|_| DeviceForLoadTensor::Base),
348    )?;
349
350    info!("Loading all ranks.");
351    // The mapper is specific to this pipeline
352    let mapper = DeviceMapSetting::Nccl {
353        nm_device: available_devices[0].clone(),
354        comm: Arc::new(comm),
355    }
356    .into_mapper(model.num_layers(config)?, device, None)?;
357
358    let sharded_vb = if !loading_isq {
359        sharded_vb.clone().set_device(device.clone())
360    } else {
361        sharded_vb.clone()
362    };
363
364    Ok((mapper, sharded_vb))
365}