mistralrs_core/distributed.rs

use anyhow::Context;
use candle_core::{DType, Device};
use core::ffi::c_char;
use interprocess::local_socket::traits::{Listener, Stream};
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
pub use mistralrs_quant::distributed::use_nccl;
use mistralrs_quant::{ShardedSafeTensors, ShardedVarBuilder};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use std::env;
use std::io::{BufRead, BufReader, Write};
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
use crate::{DeviceMapSetting, IsqOrganization, ModelPaths};

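/// Environment variable that marks a process as a spawned daemon. Its value is the
/// JSON-serialized [`WorkerTransferData`] payload for that worker.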
pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";

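/// Returns true if this process was spawned as a daemon (worker) by the head process.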
pub fn is_daemon() -> bool {
    std::env::var(IS_DAEMON_FLAG).is_ok()
}

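/// Fixed-size buffer for the NCCL unique id, which is 128 bytes. `serde_big_array` is
/// needed because serde only derives support for arrays up to length 32.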
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);

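/// Payload handed to each spawned daemon through the [`IS_DAEMON_FLAG`] environment variable.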
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        id: BigCCharArray,
        worker_rank: usize,
    },
}

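/// Namespaced local-socket name used for the head/daemon readiness handshake.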
pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
    let printname = "mistralrs_daemon.sock";
    Ok(printname.to_ns_name::<GenericNamespaced>()?)
}

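/// Set up tensor-parallel loading over NCCL.
///
/// The head process generates a NCCL unique id, re-execs one daemon per extra local
/// device, and (in multi-node mode) broadcasts the id to worker nodes. Multi-node
/// behavior is driven by the environment variables read below:
/// `MISTRALRS_MN_GLOBAL_WORLD_SIZE`, `MISTRALRS_MN_HEAD_NUM_WORKERS`,
/// `MISTRALRS_MN_HEAD_PORT`, `MISTRALRS_MN_WORKER_SERVER_ADDR`, and
/// `MISTRALRS_MN_WORKER_ID`.
///
/// Returns the NCCL device mapper and a sharded var builder for the weights.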
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
    dtype: DType,
    device: &Device,
    load_device: &Device,
    available_devices: &[Device],
    config: &str,
    loading_isq: bool,
    from_uqff: bool,
    organization: IsqOrganization,
    model: &T,
    paths: &dyn ModelPaths,
) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
    #[cfg(not(feature = "nccl"))]
    tracing::warn!(
        "NCCL support was not included in this build; be sure to build with `--features nccl`."
    );

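    // One rank per local device; MISTRALRS_MN_GLOBAL_WORLD_SIZE extends this across nodes.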
    let local_world_size = available_devices.len();
    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
    } else {
        mistralrs_quant::distributed::get_global_tp_size_from_devices()?
    };

    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
    if use_multi_node {
        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
    }

    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
        anyhow::bail!("Global world size {global_world_size} must be at least the local world size {local_world_size} and a multiple of it");
    }

    info!("Local tensor parallel world size is {local_world_size}");
    info!("Global tensor parallel world size is {global_world_size}");

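    // Daemon processes read the NCCL id and their rank from the env payload; the head
    // process generates the id, re-execs itself once per extra device, and waits for a
    // "ready" handshake from every child over a local socket.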
    let name = ipc_name()?;
    let mut id;
    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
        let WorkerTransferData::Init {
            id: new_id,
            worker_rank,
        } = payload;
        id = mistralrs_quant::Id::uninit(new_id.0);

        let mut stream = LocalStream::connect(name)?;
        stream.write_all(b"ready\n")?;
        worker_rank + 1
    } else {
        id = mistralrs_quant::Id::new();
        let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
        let mut children = Vec::new();
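        // Re-exec the current binary with the same CLI arguments, passing the NCCL id and
        // the worker's rank through the daemon env var; child stdio is silenced.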
        for worker_rank in 0..num_workers {
            let exe_path = env::current_exe().expect("Failed to get current exe");

            let args: Vec<String> = env::args().collect();

            let mut cmd = Command::new(exe_path);
            cmd.args(&args[1..]);

            let data = WorkerTransferData::Init {
                id: BigCCharArray(*id.internal()),
                worker_rank,
            };

            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);

            cmd.stdout(std::process::Stdio::null());
            cmd.stderr(std::process::Stdio::null());
            cmd.stdin(std::process::Stdio::null());

            children.push(cmd.spawn().expect("Failed to spawn process"));
        }

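        // Block until every spawned daemon has connected and reported readiness.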
        let listener = ListenerOptions::new().name(name).create_sync()?;
        let mut ready_count = 0;

        while ready_count < num_workers {
            let stream = listener.accept()?;
            let mut reader = BufReader::new(stream);
            let mut message = String::new();
            reader.read_line(&mut message)?;
            if message.trim() == "ready" {
                ready_count += 1;
            }
        }
        info!("All workers have received the ids!");

        0
    };

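    // Multi-node: the head node serves the NCCL id over TCP to each worker node, and
    // worker nodes replace their locally generated id with the received one.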
    if use_multi_node {
        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
            info!("Head node managing {n_nodes} workers.");
            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
            };
            info!("Head node initializing connection on {port}.");
            let server =
                mistralrs_quant::Server::new(&format!("0.0.0.0:{port}"), n_nodes, local_world_size)?;

            server.broadcast_id(&id)?;
        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
            info!("Worker node connecting to {addr}.");
            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

            id = client.receive_id()?;
        }
    }

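    // The head node owns global ranks [0, local_world_size); worker node N owns the
    // block starting at (N + 1) * local_world_size.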
    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
        };
        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
        info!("Worker ID is {node_id}.");
        (node_id + 1) * local_world_size
    } else {
        0
    };

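    // Create the NCCL communicator for this process's global rank.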
    let comm = mistralrs_quant::Comm::from_device(
        id,
        device,
        local_rank + rank_offset,
        global_world_size,
    )?;

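    // When applying ISQ to a model loaded from a UQFF artifact, tensors matched by the
    // ISQ regexes are loaded as dummies: their data comes from the UQFF file instead.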
    let make_dummy_regexes = if loading_isq && from_uqff {
        Some(std::sync::Arc::new(
            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
                model.isq_layer_regexes_moqe(config)?
            } else {
                model.isq_layer_regexes(config)?
            },
        ))
    } else {
        None
    };

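    // Safety: the weight shards are memory-mapped, so the underlying files must not be
    // modified while the var builder is in use.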
    let sharded_vb = unsafe {
        ShardedSafeTensors::sharded(
            paths.get_weight_filenames(),
            dtype,
            load_device,
            make_dummy_regexes,
        )?
    };

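    // Build the NCCL device mapper, which routes every layer through this rank's communicator.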
    info!("Loading all ranks.");
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    }
    .into_mapper(model.num_layers(config)?, device, None)?;

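    // Without ISQ, tensors can go straight to the target device as they load; with ISQ
    // they stay on the load device until quantization runs.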
    let sharded_vb = if !loading_isq {
        sharded_vb.clone().set_device(device.clone())
    } else {
        sharded_vb.clone()
    };

    Ok((mapper, sharded_vb))
}