// mistralrs_core/distributed.rs
use anyhow::Context;
2use candle_core::{DType, Device};
3use core::ffi::c_char;
4use interprocess::local_socket::traits::{Listener, Stream};
5use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
6use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
7pub use mistralrs_quant::distributed::use_nccl;
8use mistralrs_quant::{RingConfig, ShardedVarBuilder};
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11use std::env;
12use std::io::{BufRead, BufReader, Write};
13use std::net::TcpStream;
14use std::process::Command;
15use std::str::FromStr;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tokio::sync::mpsc::Sender;
19use tracing::info;
20
21use crate::device_map::DeviceMapper;
22use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
23use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
24use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};
25
/// Environment variable set on spawned worker (daemon) processes; its value is
/// the JSON-encoded `WorkerTransferData` for that worker.
pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";
27
28pub fn is_daemon() -> bool {
29 if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
30 std::env::var(IS_DAEMON_FLAG).is_ok()
31 } else if cfg!(feature = "ring") {
32 !RingConfig::load().is_master_rank()
33 } else {
34 false
35 }
36}
37
/// Spawns a background thread on an NCCL worker (daemon) process that
/// connects to the master over a local IPC socket, reads JSON-serialized
/// [`Request`]s line-by-line, and replays them into this process's own
/// request channel via `request_sender`.
///
/// For request variants carrying a response channel (`Detokenize`,
/// `Tokenize`, `Normal`), the deserialized response sender is replaced with
/// a fresh local channel: the worker only mirrors the master's work, so the
/// response is awaited (to keep the pipelines in lockstep) and then
/// discarded, with errors logged. `Normal` requests are forced to
/// non-streaming since no client is attached on the worker side.
pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
    use std::io::BufRead;
    use std::io::BufReader;

    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            use interprocess::local_socket::traits::Stream;
            use interprocess::local_socket::Stream as LocalStream;

            loop {
                let name = match ipc_name() {
                    Ok(name) => name,
                    Err(e) => {
                        tracing::error!("Failed to get IPC name in daemon: {e}");
                        continue;
                    }
                };
                // One request is read per connection; on any error we log and
                // reconnect rather than tearing the daemon down.
                if let Ok(stream) = LocalStream::connect(name) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    if let Err(e) = reader.read_line(&mut buf) {
                        tracing::error!("Failed to read line from IPC stream: {e}");
                        continue;
                    }
                    let mut req: Request = match serde_json::from_str(&buf) {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::error!("Failed to parse request JSON: {e}");
                            continue;
                        }
                    };

                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            // Swap in a local response channel; the real client
                            // lives on the master.
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Detokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Detokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Detokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Tokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Tokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Tokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Normal(mut x) => {
                            // No client attached on the worker: disable streaming
                            // and consume the response locally.
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Normal request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp.as_result() {
                                        tracing::error!("Normal response error: {e}");
                                    }
                                }
                                None => tracing::error!("Normal response channel closed"),
                            }
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    // Variants without a response channel are forwarded as-is.
                    if request_sender.send(req).await.is_err() {
                        tracing::error!("Daemon channel closed for request");
                    }
                }
            }
        });
    });
}
143
144pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
145 use std::io::BufRead;
146 use std::io::BufReader;
147
148 let ring_config = RingConfig::load();
149
150 let master_ip = ring_config.master_ip();
151 let master_port = ring_config.master_port;
152 std::thread::spawn(move || {
153 let rt = Runtime::new().unwrap();
154 rt.block_on(async move {
155 loop {
156 if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
157 let mut reader = BufReader::new(stream);
158 let mut buf = String::new();
159 reader.read_line(&mut buf).unwrap();
160 let mut req: Request = serde_json::from_str(&buf).unwrap();
161
162 req = match req {
163 Request::ReIsq(x) => Request::ReIsq(x),
164 Request::Terminate => Request::Terminate,
165 Request::Detokenize(mut x) => {
166 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
167 x.response = sender;
168 let req = Request::Detokenize(x);
169
170 request_sender.send(req).await.unwrap();
171 let resp = receiver.recv().await.unwrap();
172 resp.unwrap();
173 continue;
174 }
175 Request::Tokenize(mut x) => {
176 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
177 x.response = sender;
178 let req = Request::Tokenize(x);
179
180 request_sender.send(req).await.unwrap();
181 let resp = receiver.recv().await.unwrap();
182 resp.unwrap();
183 continue;
184 }
185 Request::Normal(mut x) => {
186 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
187 x.is_streaming = false;
188 x.response = sender;
189 let req = Request::Normal(x);
190
191 request_sender.send(req).await.unwrap();
192 let resp = receiver.recv().await.unwrap();
193 resp.as_result().unwrap();
194 continue;
195 }
196 Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
197 };
198
199 request_sender.send(req).await.unwrap();
200 }
201 }
202 });
203 });
204}
205
/// Wrapper for a fixed 128-byte `c_char` array so it can round-trip through
/// serde (arrays longer than 32 elements need `serde_big_array`). Used to
/// transfer the communicator id to spawned worker processes.
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);
209
/// Payload handed to spawned worker processes, JSON-encoded in the
/// `IS_DAEMON_FLAG` environment variable.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        // Communicator id generated by rank 0, shared with every worker.
        id: BigCCharArray,
        // Zero-based index of this worker among the spawned workers.
        worker_rank: usize,
    },
}
217
218pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
219 let printname = "mistralrs_daemon.sock";
220 Ok(printname.to_ns_name::<GenericNamespaced>()?)
221}
222
223#[allow(clippy::too_many_arguments)]
224pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
225 dtype: DType,
226 device: &Device,
227 available_devices: &[Device],
228 silent: bool,
229 config: &str,
230 loading_isq: bool,
231 from_uqff: bool,
232 organization: IsqOrganization,
233 model: &T,
234 paths: &dyn ModelPaths,
235) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
236 if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
237 tracing::warn!(
238 "Distributed support was not included in the build, be sure to build with `--features nccl`."
239 );
240 }
241
242 let local_world_size = available_devices.len();
245 let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
246 usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
247 } else {
248 mistralrs_quant::distributed::get_global_tp_size_from_devices()?
249 };
250
251 let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
252 if use_multi_node {
253 info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
254 }
255
256 if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
257 anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
258 }
259
260 info!("Local tensor parallel world size is {local_world_size}");
261 info!("Global tensor parallel world size is {global_world_size}");
262
263 let name = ipc_name()?;
265 let mut id;
266 let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
267 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
268 let WorkerTransferData::Init {
269 id: new_id,
270 worker_rank,
271 } = payload;
272 id = mistralrs_quant::Id::uninit(new_id.0);
273
274 let mut stream = LocalStream::connect(name)?;
275 stream.write_all(b"ready\n")?;
276 worker_rank + 1
277 } else if cfg!(feature = "ring") {
278 id = mistralrs_quant::Id::new();
279
280 let config = RingConfig::load();
281
282 config.rank
283 } else {
284 id = mistralrs_quant::Id::new();
285 let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
286 let mut children = Vec::new();
287 for worker_rank in 0..num_workers {
288 let exe_path = env::current_exe().expect("Failed to get current exe");
289
290 let args: Vec<String> = env::args().collect();
291
292 let mut cmd = Command::new(exe_path);
293 cmd.args(&args[1..]);
294
295 let data = WorkerTransferData::Init {
296 id: BigCCharArray(*id.internal()),
297 worker_rank,
298 };
299
300 cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);
301
302 cmd.stdout(std::process::Stdio::null());
303 cmd.stderr(std::process::Stdio::null());
304 cmd.stdin(std::process::Stdio::null());
305
306 children.push(cmd.spawn().expect("Failed to spawn process"));
307 }
308
309 let listener = ListenerOptions::new().name(name).create_sync()?;
310 let mut ready_count = 0;
311
312 while ready_count < num_workers {
313 let stream = listener.accept()?;
314 let mut reader = BufReader::new(stream);
315 let mut message = String::new();
316 reader.read_line(&mut message)?;
317 if message.trim() == "ready" {
318 ready_count += 1;
319 }
320 }
321 info!("All workers have received the ids!");
322
323 0
324 };
325
326 if use_multi_node {
327 if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
328 let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
329 info!("Head node managing {n_nodes} workers.");
330 let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
331 anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
332 };
333 info!("Head node initializing connection on {port}.");
334 let server = mistralrs_quant::Server::new(
335 &format!("0.0.0.0:{port}"),
336 n_nodes,
337 local_world_size,
338 )?;
339
340 server.broadcast_id(&id)?;
341 } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
342 info!("Worker node connecting to {addr}.");
343 let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;
344
345 id = client.receive_id()?;
346 }
347 }
348
349 let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
350 let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
351 anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
352 };
353 let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
354 info!("Worker ID is {node_id}.");
355 (node_id + 1) * local_world_size
356 } else {
357 0
358 };
359
360 let comm = mistralrs_quant::Comm::from_device(
363 id,
364 device,
365 local_rank + rank_offset,
366 global_world_size,
367 )?;
368
369 let make_dummy_regexes = if loading_isq && from_uqff {
370 Some(std::sync::Arc::new(
372 if matches!(organization, IsqOrganization::MoeExpertsOnly) {
373 model.isq_layer_regexes_moqe(config)?
374 } else {
375 model.isq_layer_regexes(config)?
376 },
377 ))
378 } else {
379 None
380 };
381
382 let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
383 paths.get_weight_filenames().to_vec(),
384 vec![],
385 Some(dtype),
386 &Device::Cpu,
387 vec![],
388 silent,
389 make_dummy_regexes,
390 |_| true,
391 Arc::new(|_| DeviceForLoadTensor::Base),
392 )?;
393
394 info!("Loading all ranks.");
395 let mapper = DeviceMapSetting::Nccl {
397 nm_device: available_devices[0].clone(),
398 comm: Arc::new(comm),
399 }
400 .into_mapper(model.num_layers(config)?, device, None)?;
401
402 let sharded_vb = if !loading_isq {
403 sharded_vb.clone().set_device(device.clone())
404 } else {
405 sharded_vb.clone()
406 };
407
408 Ok((mapper, sharded_vb))
409}