mistralrs_core/distributed.rs

use anyhow::Context;
use candle_core::{DType, Device};
use core::ffi::c_char;
use interprocess::local_socket::traits::{Listener, Stream};
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
pub use mistralrs_quant::distributed::use_nccl;
use mistralrs_quant::{RingConfig, ShardedVarBuilder};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use std::env;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::mpsc::Sender;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};

pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";

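/// Whether this process is a spawned worker (daemon) rather than the master:
/// under NCCL (`cuda` without `ring`) this is signalled by the presence of the
/// `IS_DAEMON_FLAG` environment variable; under the `ring` backend, any
/// non-master rank is a daemon.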
pub fn is_daemon() -> bool {
    if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
        std::env::var(IS_DAEMON_FLAG).is_ok()
    } else if cfg!(feature = "ring") {
        !RingConfig::load().is_master_rank()
    } else {
        false
    }
}

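/// Spawns a thread that repeatedly connects to the master's local IPC socket,
/// reads one JSON-encoded `Request` per line, and replays it into this daemon's
/// own request channel. Responses are drained on a local channel and only
/// logged on error.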
pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
    use std::io::BufRead;
    use std::io::BufReader;

    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            use interprocess::local_socket::traits::Stream;
            use interprocess::local_socket::Stream as LocalStream;

            loop {
                let name = match ipc_name() {
                    Ok(name) => name,
                    Err(e) => {
                        tracing::error!("Failed to get IPC name in daemon: {e}");
                        continue;
                    }
                };
                if let Ok(stream) = LocalStream::connect(name) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    if let Err(e) = reader.read_line(&mut buf) {
                        tracing::error!("Failed to read line from IPC stream: {e}");
                        continue;
                    }
                    let mut req: Request = match serde_json::from_str(&buf) {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::error!("Failed to parse request JSON: {e}");
                            continue;
                        }
                    };

                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Detokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Detokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Detokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Tokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Tokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Tokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Normal request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp.as_result() {
                                        tracing::error!("Normal response error: {e}");
                                    }
                                }
                                None => tracing::error!("Normal response channel closed"),
                            }
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    if request_sender.send(req).await.is_err() {
                        tracing::error!("Daemon channel closed for request");
                    }
                }
            }
        });
    });
}

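/// Ring-backend counterpart of [`nccl_daemon_replicator`]: connects to the master
/// rank over TCP at `master_ip:master_port` from the [`RingConfig`], then replays
/// each JSON-encoded `Request` into the local request channel, draining responses
/// on a local channel.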
pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
    use std::io::BufRead;
    use std::io::BufReader;

    let ring_config = RingConfig::load();

    let master_ip = ring_config.master_ip();
    let master_port = ring_config.master_port;
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    reader.read_line(&mut buf).unwrap();
                    let mut req: Request = serde_json::from_str(&buf).unwrap();

                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.as_result().unwrap();
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    request_sender.send(req).await.unwrap();
                }
            }
        });
    });
}

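/// Serializable wrapper around the raw 128-byte communicator id; `serde_big_array`
/// is needed because serde does not derive for arrays of this length.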
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);

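/// Payload handed to spawned worker processes via the `IS_DAEMON_FLAG` environment
/// variable: the shared communicator id plus the worker's rank.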
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        id: BigCCharArray,
        worker_rank: usize,
    },
}

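/// Namespaced local-socket name used for the master/worker readiness handshake.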
pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
    let printname = "mistralrs_daemon.sock";
    Ok(printname.to_ns_name::<GenericNamespaced>()?)
}

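/// Prepares tensor-parallel loading: computes the local and global world sizes,
/// spawns worker daemons (NCCL) or reads the ring rank, exchanges the communicator
/// id across nodes when multi-node is enabled, and returns an NCCL device mapper
/// together with a sharded var builder over the model's weight files.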
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
    dtype: DType,
    device: &Device,
    available_devices: &[Device],
    silent: bool,
    config: &str,
    loading_isq: bool,
    from_uqff: bool,
    organization: IsqOrganization,
    model: &T,
    paths: &dyn ModelPaths,
) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
    if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
        tracing::warn!(
            "Distributed support was not included in the build; be sure to build with `--features nccl`."
        );
    }

    let local_world_size = available_devices.len();

    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
    } else {
        std::cmp::max(
            mistralrs_quant::distributed::get_global_tp_size_from_devices()?,
            local_world_size,
        )
    };

    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
    if use_multi_node {
        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
    }

    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
        anyhow::bail!("Global world size {global_world_size} must be at least the local world size {local_world_size} and divisible by it");
    }

    info!("Local tensor parallel world size is {local_world_size}");
    info!("Global tensor parallel world size is {global_world_size}");

    let name = ipc_name()?;

    let mut id;
    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
        let WorkerTransferData::Init {
            id: new_id,
            worker_rank,
        } = payload;
        id = mistralrs_quant::Id::uninit(new_id.0);

        let mut stream = LocalStream::connect(name)?;
        stream.write_all(b"ready\n")?;
        worker_rank + 1
    } else if cfg!(feature = "ring") {
        id = mistralrs_quant::Id::new();

        let config = RingConfig::load();

        config.rank
    } else {
        id = mistralrs_quant::Id::new();
        let num_ranks = mistralrs_quant::distributed::get_global_tp_size_from_devices()?;
        let num_workers = num_ranks - 1;
        let mut children = Vec::new();
        for worker_rank in 0..num_workers {
            let exe_path = env::current_exe().expect("Failed to get current exe");

            let args: Vec<String> = env::args().collect();

            let mut cmd = Command::new(exe_path);
            cmd.args(&args[1..]);

            let data = WorkerTransferData::Init {
                id: BigCCharArray(*id.internal()),
                worker_rank,
            };

            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);

            cmd.stdout(std::process::Stdio::null());
            cmd.stderr(std::process::Stdio::null());
            cmd.stdin(std::process::Stdio::null());

            children.push(cmd.spawn().expect("Failed to spawn process"));
        }

        let listener = ListenerOptions::new().name(name).create_sync()?;
        let mut ready_count = 0;

        while ready_count < num_workers {
            let stream = listener.accept()?;
            let mut reader = BufReader::new(stream);
            let mut message = String::new();
            reader.read_line(&mut message)?;
            if message.trim() == "ready" {
                ready_count += 1;
            }
        }
        info!("All workers have received the ids!");

        0
    };

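    // In multi-node mode the head node broadcasts the communicator id to every
    // worker node over TCP; worker nodes replace their locally generated id with it.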
    if use_multi_node {
        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
            info!("Head node managing {n_nodes} workers.");
            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
            };
            info!("Head node initializing connection on {port}.");
            let server = mistralrs_quant::Server::new(
                &format!("0.0.0.0:{port}"),
                n_nodes,
                local_world_size,
            )?;

            server.broadcast_id(&id)?;
        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
            info!("Worker node connecting to {addr}.");
            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

            id = client.receive_id()?;
        }
    }

    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
        };
        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
        info!("Worker ID is {node_id}.");
        (node_id + 1) * local_world_size
    } else {
        0
    };

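    // The global rank of this process is its local rank offset by the node's position.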
    let comm = mistralrs_quant::Comm::from_device(
        id,
        device,
        local_rank + rank_offset,
        global_world_size,
    )?;

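    // When ISQ weights will come from a pre-quantized UQFF artifact, the ISQ layer
    // regexes are passed as "dummy" regexes so those tensors are not materialized
    // from the safetensors here.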
    let make_dummy_regexes = if loading_isq && from_uqff {
        Some(std::sync::Arc::new(
            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
                model.isq_layer_regexes_moqe(config)?
            } else {
                model.isq_layer_regexes(config)?
            },
        ))
    } else {
        None
    };

    let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
        paths.get_weight_filenames().to_vec(),
        vec![],
        Some(dtype),
        &Device::Cpu,
        vec![],
        silent,
        make_dummy_regexes,
        |_| true,
        Arc::new(|_| DeviceForLoadTensor::Base),
    )?;

    info!("Loading all ranks.");
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    }
    .into_mapper(model.num_layers(config)?, device, None)?;

    let sharded_vb = if !loading_isq {
        sharded_vb.clone().set_device(device.clone())
    } else {
        sharded_vb.clone()
    };

    Ok((mapper, sharded_vb))
}