// mistralrs_core/distributed.rs
use anyhow::Context;
2use candle_core::{DType, Device};
3use core::ffi::c_char;
4use interprocess::local_socket::traits::{Listener, Stream};
5use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
6use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
7pub use mistralrs_quant::distributed::use_nccl;
8use mistralrs_quant::{RingConfig, ShardedVarBuilder};
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11use std::env;
12use std::io::{BufRead, BufReader, Write};
13use std::net::TcpStream;
14use std::process::Command;
15use std::str::FromStr;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tokio::sync::mpsc::Sender;
19use tracing::info;
20
21use crate::device_map::DeviceMapper;
22use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
23use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
24use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};
25
/// Environment variable marking a process as a spawned daemon worker; its
/// value is the JSON-encoded [`WorkerTransferData`] handshake payload.
pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";
27
28pub fn is_daemon() -> bool {
29 if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
30 std::env::var(IS_DAEMON_FLAG).is_ok()
31 } else if cfg!(feature = "ring") {
32 !RingConfig::load().is_master_rank()
33 } else {
34 false
35 }
36}
37
/// Spawns a background thread that keeps this NCCL daemon (worker) process in
/// lockstep with the head process.
///
/// The thread loops forever: it connects to the head's local IPC socket (see
/// [`ipc_name`]), reads one JSON-encoded [`Request`] per line, and forwards it
/// to the local engine via `request_sender`. For variants carrying a response
/// channel (`Tokenize`, `Detokenize`, `Normal`), a throwaway local channel is
/// substituted and the response is awaited and discarded — only the head
/// process replies to clients.
pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
    use std::io::BufRead;
    use std::io::BufReader;

    std::thread::spawn(move || {
        // Dedicated runtime: this thread lives outside the main async runtime.
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            use interprocess::local_socket::traits::Stream;
            use interprocess::local_socket::Stream as LocalStream;

            loop {
                let name = ipc_name().unwrap();
                // If the head is not accepting yet, silently retry the connect.
                if let Ok(stream) = LocalStream::connect(name) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    reader.read_line(&mut buf).unwrap();
                    let mut req: Request = serde_json::from_str(&buf).unwrap();

                    // Response channels cannot cross the serialization
                    // boundary; rebuild those variants with local channels,
                    // send, and drain the reply here.
                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            request_sender.send(req).await.unwrap();
                            // Await and discard: only the head answers clients.
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            // Workers must not stream; force a unary response.
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.as_result().unwrap();
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    // Channel-free variants are forwarded as-is.
                    request_sender.send(req).await.unwrap();
                }
            }
        });
    });
}
99
/// Ring-backend analogue of [`nccl_daemon_replicator`]: spawns a thread that
/// replicates requests from the master rank to this worker rank.
///
/// Instead of a local IPC socket, the worker connects over TCP to the master
/// (address and port taken from [`RingConfig`]), reads one JSON-encoded
/// [`Request`] per line, and forwards it to the local engine via
/// `request_sender`. Response channels are replaced with throwaway local
/// channels whose results are awaited and discarded — only the master rank
/// replies to clients.
pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
    use std::io::BufRead;
    use std::io::BufReader;

    // Resolve the master endpoint once, before moving into the thread.
    let ring_config = RingConfig::load();

    let master_ip = ring_config.master_ip();
    let master_port = ring_config.master_port;
    std::thread::spawn(move || {
        // Dedicated runtime: this thread lives outside the main async runtime.
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                // Retry until the master starts listening.
                if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    reader.read_line(&mut buf).unwrap();
                    let mut req: Request = serde_json::from_str(&buf).unwrap();

                    // Response channels cannot cross the serialization
                    // boundary; rebuild those variants with local channels,
                    // send, and drain the reply here.
                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            request_sender.send(req).await.unwrap();
                            // Await and discard: only the master answers clients.
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            // Workers must not stream; force a unary response.
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.as_result().unwrap();
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    // Channel-free variants are forwarded as-is.
                    request_sender.send(req).await.unwrap();
                }
            }
        });
    });
}
161
/// A `[c_char; 128]` newtype that serde can (de)serialize.
///
/// Serde's derive only supports arrays up to 32 elements natively, so
/// `serde_big_array`'s `BigArray` supplies the impl; `#[serde(transparent)]`
/// keeps the wire format identical to the bare array. Used to ship the
/// communicator id to worker processes (see [`WorkerTransferData`]).
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);
165
/// Handshake payload the head process passes to each spawned worker, JSON
/// encoded in the [`IS_DAEMON_FLAG`] environment variable.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        /// Communicator id created by the head, shared verbatim with workers.
        id: BigCCharArray,
        /// Zero-based worker index; the worker's local rank becomes
        /// `worker_rank + 1` since the head holds rank 0.
        worker_rank: usize,
    },
}
173
174pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
175 let printname = "mistralrs_daemon.sock";
176 Ok(printname.to_ns_name::<GenericNamespaced>()?)
177}
178
179#[allow(clippy::too_many_arguments)]
180pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
181 dtype: DType,
182 device: &Device,
183 available_devices: &[Device],
184 silent: bool,
185 config: &str,
186 loading_isq: bool,
187 from_uqff: bool,
188 organization: IsqOrganization,
189 model: &T,
190 paths: &dyn ModelPaths,
191) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
192 if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
193 tracing::warn!(
194 "Distributed support was not included in the build, be sure to build with `--features nccl`."
195 );
196 }
197
198 let local_world_size = available_devices.len();
201 let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
202 usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
203 } else {
204 mistralrs_quant::distributed::get_global_tp_size_from_devices()?
205 };
206
207 let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
208 if use_multi_node {
209 info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
210 }
211
212 if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
213 anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
214 }
215
216 info!("Local tensor parallel world size is {local_world_size}");
217 info!("Global tensor parallel world size is {global_world_size}");
218
219 let name = ipc_name()?;
221 let mut id;
222 let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
223 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
224 let WorkerTransferData::Init {
225 id: new_id,
226 worker_rank,
227 } = payload;
228 id = mistralrs_quant::Id::uninit(new_id.0);
229
230 let mut stream = LocalStream::connect(name)?;
231 stream.write_all(b"ready\n")?;
232 worker_rank + 1
233 } else if cfg!(feature = "ring") {
234 id = mistralrs_quant::Id::new();
235
236 let config = RingConfig::load();
237
238 config.rank
239 } else {
240 id = mistralrs_quant::Id::new();
241 let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
242 let mut children = Vec::new();
243 for worker_rank in 0..num_workers {
244 let exe_path = env::current_exe().expect("Failed to get current exe");
245
246 let args: Vec<String> = env::args().collect();
247
248 let mut cmd = Command::new(exe_path);
249 cmd.args(&args[1..]);
250
251 let data = WorkerTransferData::Init {
252 id: BigCCharArray(*id.internal()),
253 worker_rank,
254 };
255
256 cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);
257
258 cmd.stdout(std::process::Stdio::null());
259 cmd.stderr(std::process::Stdio::null());
260 cmd.stdin(std::process::Stdio::null());
261
262 children.push(cmd.spawn().expect("Failed to spawn process"));
263 }
264
265 let listener = ListenerOptions::new().name(name).create_sync()?;
266 let mut ready_count = 0;
267
268 while ready_count < num_workers {
269 let stream = listener.accept()?;
270 let mut reader = BufReader::new(stream);
271 let mut message = String::new();
272 reader.read_line(&mut message)?;
273 if message.trim() == "ready" {
274 ready_count += 1;
275 }
276 }
277 info!("All workers have received the ids!");
278
279 0
280 };
281
282 if use_multi_node {
283 if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
284 let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
285 info!("Head node managing {n_nodes} workers.");
286 let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
287 anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
288 };
289 info!("Head node initializing connection on {port}.");
290 let server = mistralrs_quant::Server::new(
291 &format!("0.0.0.0:{port}"),
292 n_nodes,
293 local_world_size,
294 )?;
295
296 server.broadcast_id(&id)?;
297 } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
298 info!("Worker node connecting to {addr}.");
299 let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;
300
301 id = client.receive_id()?;
302 }
303 }
304
305 let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
306 let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
307 anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
308 };
309 let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
310 info!("Worker ID is {node_id}.");
311 (node_id + 1) * local_world_size
312 } else {
313 0
314 };
315
316 let comm = mistralrs_quant::Comm::from_device(
319 id,
320 device,
321 local_rank + rank_offset,
322 global_world_size,
323 )?;
324
325 let make_dummy_regexes = if loading_isq && from_uqff {
326 Some(std::sync::Arc::new(
328 if matches!(organization, IsqOrganization::MoeExpertsOnly) {
329 model.isq_layer_regexes_moqe(config)?
330 } else {
331 model.isq_layer_regexes(config)?
332 },
333 ))
334 } else {
335 None
336 };
337
338 let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
339 paths.get_weight_filenames().to_vec(),
340 vec![],
341 Some(dtype),
342 &Device::Cpu,
343 vec![],
344 silent,
345 make_dummy_regexes,
346 |_| true,
347 Arc::new(|_| DeviceForLoadTensor::Base),
348 )?;
349
350 info!("Loading all ranks.");
351 let mapper = DeviceMapSetting::Nccl {
353 nm_device: available_devices[0].clone(),
354 comm: Arc::new(comm),
355 }
356 .into_mapper(model.num_layers(config)?, device, None)?;
357
358 let sharded_vb = if !loading_isq {
359 sharded_vb.clone().set_device(device.clone())
360 } else {
361 sharded_vb.clone()
362 };
363
364 Ok((mapper, sharded_vb))
365}