// mistralrs_core/distributed.rs
use anyhow::Context;
2use candle_core::{DType, Device};
3use core::ffi::c_char;
4use interprocess::local_socket::traits::{Listener, Stream};
5use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
6use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
7pub use mistralrs_quant::distributed::use_nccl;
8use mistralrs_quant::{RingConfig, ShardedVarBuilder};
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11use std::env;
12use std::io::{BufRead, BufReader, Write};
13use std::net::TcpStream;
14use std::process::Command;
15use std::str::FromStr;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tokio::sync::mpsc::Sender;
19use tracing::info;
20
21use crate::device_map::DeviceMapper;
22use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
23use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
24use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};
25
/// Environment variable set on spawned daemon worker processes. Its value is
/// a JSON-encoded [`WorkerTransferData`]; its mere presence marks the process
/// as a worker (see [`is_daemon`]).
pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";
27
/// Returns `true` when this process is a distributed worker (daemon) rather
/// than the master rank.
///
/// - `cuda` feature: a worker is identified by the presence of the
///   [`IS_DAEMON_FLAG`] environment variable set by the master when spawning.
/// - `ring` feature: every rank except the ring master counts as a daemon.
/// - With neither feature there is no distributed mode, so always `false`.
pub fn is_daemon() -> bool {
    #[cfg(feature = "cuda")]
    {
        env::var(IS_DAEMON_FLAG).is_ok()
    }
    #[cfg(feature = "ring")]
    {
        !RingConfig::load().is_master_rank()
    }
    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    false
}
40
41pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
42 use std::io::BufRead;
43 use std::io::BufReader;
44
45 std::thread::spawn(move || {
46 let rt = Runtime::new().unwrap();
47 rt.block_on(async move {
48 use interprocess::local_socket::traits::Stream;
49 use interprocess::local_socket::Stream as LocalStream;
50
51 loop {
52 let name = ipc_name().unwrap();
53 if let Ok(stream) = LocalStream::connect(name) {
54 let mut reader = BufReader::new(stream);
55 let mut buf = String::new();
56 reader.read_line(&mut buf).unwrap();
57 let mut req: Request = serde_json::from_str(&buf).unwrap();
58
59 req = match req {
60 Request::ReIsq(x) => Request::ReIsq(x),
61 Request::Terminate => Request::Terminate,
62 Request::Detokenize(mut x) => {
63 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
64 x.response = sender;
65 let req = Request::Detokenize(x);
66
67 request_sender.send(req).await.unwrap();
68 let resp = receiver.recv().await.unwrap();
69 resp.unwrap();
70 continue;
71 }
72 Request::Tokenize(mut x) => {
73 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
74 x.response = sender;
75 let req = Request::Tokenize(x);
76
77 request_sender.send(req).await.unwrap();
78 let resp = receiver.recv().await.unwrap();
79 resp.unwrap();
80 continue;
81 }
82 Request::Normal(mut x) => {
83 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
84 x.is_streaming = false;
85 x.response = sender;
86 let req = Request::Normal(x);
87
88 request_sender.send(req).await.unwrap();
89 let resp = receiver.recv().await.unwrap();
90 resp.as_result().unwrap();
91 continue;
92 }
93 Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
94 };
95
96 request_sender.send(req).await.unwrap();
97 }
98 }
99 });
100 });
101}
102
103pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
104 use std::io::BufRead;
105 use std::io::BufReader;
106
107 let ring_config = RingConfig::load();
108
109 let master_ip = ring_config.master_ip();
110 let master_port = ring_config.master_port;
111 std::thread::spawn(move || {
112 let rt = Runtime::new().unwrap();
113 rt.block_on(async move {
114 loop {
115 if let Ok(stream) = TcpStream::connect(format!("{}:{}", master_ip, master_port)) {
116 let mut reader = BufReader::new(stream);
117 let mut buf = String::new();
118 reader.read_line(&mut buf).unwrap();
119 let mut req: Request = serde_json::from_str(&buf).unwrap();
120
121 req = match req {
122 Request::ReIsq(x) => Request::ReIsq(x),
123 Request::Terminate => Request::Terminate,
124 Request::Detokenize(mut x) => {
125 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
126 x.response = sender;
127 let req = Request::Detokenize(x);
128
129 request_sender.send(req).await.unwrap();
130 let resp = receiver.recv().await.unwrap();
131 resp.unwrap();
132 continue;
133 }
134 Request::Tokenize(mut x) => {
135 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
136 x.response = sender;
137 let req = Request::Tokenize(x);
138
139 request_sender.send(req).await.unwrap();
140 let resp = receiver.recv().await.unwrap();
141 resp.unwrap();
142 continue;
143 }
144 Request::Normal(mut x) => {
145 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
146 x.is_streaming = false;
147 x.response = sender;
148 let req = Request::Normal(x);
149
150 request_sender.send(req).await.unwrap();
151 let resp = receiver.recv().await.unwrap();
152 resp.as_result().unwrap();
153 continue;
154 }
155 Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
156 };
157
158 request_sender.send(req).await.unwrap();
159 }
160 }
161 });
162 });
163}
164
/// Serde-serializable wrapper for a 128-element `c_char` array (arrays longer
/// than 32 need `serde_big_array`). Used to carry the communicator id bytes
/// (`mistralrs_quant::Id::internal()`) to spawned worker processes.
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);
168
/// Payload the master hands to each spawned worker process, JSON-encoded in
/// the [`IS_DAEMON_FLAG`] environment variable.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        /// Communicator id bytes generated by the master rank.
        id: BigCCharArray,
        /// Zero-based rank of this worker among the spawned children.
        worker_rank: usize,
    },
}
176
177pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
178 let printname = "mistralrs_daemon.sock";
179 Ok(printname.to_ns_name::<GenericNamespaced>()?)
180}
181
182#[allow(clippy::too_many_arguments)]
183pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
184 dtype: DType,
185 device: &Device,
186 available_devices: &[Device],
187 silent: bool,
188 config: &str,
189 loading_isq: bool,
190 from_uqff: bool,
191 organization: IsqOrganization,
192 model: &T,
193 paths: &dyn ModelPaths,
194) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
195 if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
196 tracing::warn!(
197 "Distributed support was not included in the build, be sure to build with `--features nccl`."
198 );
199 }
200
201 let local_world_size = available_devices.len();
204 let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
205 usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
206 } else {
207 mistralrs_quant::distributed::get_global_tp_size_from_devices()?
208 };
209
210 let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
211 if use_multi_node {
212 info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
213 }
214
215 if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
216 anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
217 }
218
219 info!("Local tensor parallel world size is {local_world_size}");
220 info!("Global tensor parallel world size is {global_world_size}");
221
222 let name = ipc_name()?;
224 let mut id;
225 let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
226 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
227 let WorkerTransferData::Init {
228 id: new_id,
229 worker_rank,
230 } = payload;
231 id = mistralrs_quant::Id::uninit(new_id.0);
232
233 let mut stream = LocalStream::connect(name)?;
234 stream.write_all(b"ready\n")?;
235 worker_rank + 1
236 } else if cfg!(feature = "ring") {
237 id = mistralrs_quant::Id::new();
238
239 let config = RingConfig::load();
240
241 config.rank
242 } else {
243 id = mistralrs_quant::Id::new();
244 let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
245 let mut children = Vec::new();
246 for worker_rank in 0..num_workers {
247 let exe_path = env::current_exe().expect("Failed to get current exe");
248
249 let args: Vec<String> = env::args().collect();
250
251 let mut cmd = Command::new(exe_path);
252 cmd.args(&args[1..]);
253
254 let data = WorkerTransferData::Init {
255 id: BigCCharArray(*id.internal()),
256 worker_rank,
257 };
258
259 cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);
260
261 cmd.stdout(std::process::Stdio::null());
262 cmd.stderr(std::process::Stdio::null());
263 cmd.stdin(std::process::Stdio::null());
264
265 children.push(cmd.spawn().expect("Failed to spawn process"));
266 }
267
268 let listener = ListenerOptions::new().name(name).create_sync()?;
269 let mut ready_count = 0;
270
271 while ready_count < num_workers {
272 let stream = listener.accept()?;
273 let mut reader = BufReader::new(stream);
274 let mut message = String::new();
275 reader.read_line(&mut message)?;
276 if message.trim() == "ready" {
277 ready_count += 1;
278 }
279 }
280 info!("All workers have received the ids!");
281
282 0
283 };
284
285 if use_multi_node {
286 if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
287 let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
288 info!("Head node managing {n_nodes} workers.");
289 let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
290 anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
291 };
292 info!("Head node initializing connection on {port}.");
293 let server = mistralrs_quant::Server::new(
294 &format!("0.0.0.0:{port}"),
295 n_nodes,
296 local_world_size,
297 )?;
298
299 server.broadcast_id(&id)?;
300 } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
301 info!("Worker node connecting to {addr}.");
302 let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;
303
304 id = client.receive_id()?;
305 }
306 }
307
308 let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
309 let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
310 anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
311 };
312 let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
313 info!("Worker ID is {node_id}.");
314 (node_id + 1) * local_world_size
315 } else {
316 0
317 };
318
319 let comm = mistralrs_quant::Comm::from_device(
322 id,
323 device,
324 local_rank + rank_offset,
325 global_world_size,
326 )?;
327
328 let make_dummy_regexes = if loading_isq && from_uqff {
329 Some(std::sync::Arc::new(
331 if matches!(organization, IsqOrganization::MoeExpertsOnly) {
332 model.isq_layer_regexes_moqe(config)?
333 } else {
334 model.isq_layer_regexes(config)?
335 },
336 ))
337 } else {
338 None
339 };
340
341 let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
342 paths.get_weight_filenames().to_vec(),
343 vec![],
344 Some(dtype),
345 &Device::Cpu,
346 vec![],
347 silent,
348 make_dummy_regexes,
349 |_| true,
350 Arc::new(|_| DeviceForLoadTensor::Base),
351 )?;
352
353 info!("Loading all ranks.");
354 let mapper = DeviceMapSetting::Nccl {
356 nm_device: available_devices[0].clone(),
357 comm: Arc::new(comm),
358 }
359 .into_mapper(model.num_layers(config)?, device, None)?;
360
361 let sharded_vb = if !loading_isq {
362 sharded_vb.clone().set_device(device.clone())
363 } else {
364 sharded_vb.clone()
365 };
366
367 Ok((mapper, sharded_vb))
368}