mistralrs_core/distributed.rs

use anyhow::Context;
use candle_core::{DType, Device};
use core::ffi::c_char;
use interprocess::local_socket::traits::{Listener, Stream};
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
pub use mistralrs_quant::distributed::use_nccl;
use mistralrs_quant::{RingConfig, ShardedVarBuilder};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use std::env;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::mpsc::Sender;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};

pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";

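/// Whether this process is a spawned daemon (worker) rather than the master.
///
/// With the `cuda` feature, workers are detected via the `IS_DAEMON_FLAG` environment
/// variable set by the spawning master; with the `ring` feature, every non-master rank
/// of the [`RingConfig`] counts as a daemon. Without either feature there is only a
/// single process, which is never a daemon.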
pub fn is_daemon() -> bool {
    #[cfg(feature = "cuda")]
    {
        std::env::var(IS_DAEMON_FLAG).is_ok()
    }
    #[cfg(feature = "ring")]
    {
        !RingConfig::load().is_master_rank()
    }
    #[cfg(not(any(feature = "cuda", feature = "ring")))]
    false
}

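/// Spawn a background thread that replicates the master's requests on this NCCL daemon.
///
/// The thread repeatedly connects to the master's local IPC socket (see [`ipc_name`]),
/// reads one JSON-encoded [`Request`] per line, and forwards it to `request_sender` so
/// this rank performs the same work as the master. Responses are awaited and discarded
/// locally; only the master reports results to the caller.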
pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                let name = ipc_name().unwrap();
                if let Ok(stream) = LocalStream::connect(name) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    reader.read_line(&mut buf).unwrap();
                    let mut req: Request = serde_json::from_str(&buf).unwrap();

                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.as_result().unwrap();
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    request_sender.send(req).await.unwrap();
                }
            }
        });
    });
}

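/// Ring-backend analogue of [`nccl_daemon_replicator`]: non-master ranks connect to the
/// master over TCP at the address given by [`RingConfig`] instead of a local IPC socket,
/// then replay each received [`Request`] into `request_sender`.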
pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
    let ring_config = RingConfig::load();

    let master_ip = ring_config.master_ip();
    let master_port = ring_config.master_port;
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                if let Ok(stream) = TcpStream::connect(format!("{}:{}", master_ip, master_port)) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    reader.read_line(&mut buf).unwrap();
                    let mut req: Request = serde_json::from_str(&buf).unwrap();

                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.unwrap();
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            request_sender.send(req).await.unwrap();
                            let resp = receiver.recv().await.unwrap();
                            resp.as_result().unwrap();
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    request_sender.send(req).await.unwrap();
                }
            }
        });
    });
}

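/// 128-byte communicator id in a serializable form; `serde_big_array` is needed because
/// serde only derives for arrays of up to 32 elements.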
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);

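/// Handshake payload passed to each spawned worker through the [`IS_DAEMON_FLAG`]
/// environment variable: the shared NCCL unique id and the worker's rank.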
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        id: BigCCharArray,
        worker_rank: usize,
    },
}

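/// Namespaced local-socket name over which the master and daemon processes communicate.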
pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
    let printname = "mistralrs_daemon.sock";
    Ok(printname.to_ns_name::<GenericNamespaced>()?)
}

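/// Build the NCCL (or ring) device mapper and the [`ShardedVarBuilder`] for
/// tensor-parallel loading: determine the local and global world sizes, spawn or join
/// the daemon processes, exchange the NCCL unique id (over TCP in multi-node mode),
/// create the communicator, and memory-map the weights for sharded loading.
///
/// A rough multi-node invocation sketch (the environment variable names are the ones
/// read below; the `mistralrs-server` binary name is illustrative):
///
/// ```text
/// # head node (1 worker node, 4 GPUs per node => global world size 8)
/// MISTRALRS_MN_GLOBAL_WORLD_SIZE=8 MISTRALRS_MN_HEAD_NUM_WORKERS=1 \
///   MISTRALRS_MN_HEAD_PORT=1234 mistralrs-server ...
/// # worker node 0
/// MISTRALRS_MN_GLOBAL_WORLD_SIZE=8 MISTRALRS_MN_WORKER_SERVER_ADDR=<head-ip>:1234 \
///   MISTRALRS_MN_WORKER_ID=0 mistralrs-server ...
/// ```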
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
    dtype: DType,
    device: &Device,
    available_devices: &[Device],
    silent: bool,
    config: &str,
    loading_isq: bool,
    from_uqff: bool,
    organization: IsqOrganization,
    model: &T,
    paths: &dyn ModelPaths,
) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
    if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
        tracing::warn!(
            "Distributed support was not included in this build; rebuild with `--features nccl` (CUDA) or `--features ring` to enable it."
        );
    }

    // NCCL case!

    let local_world_size = available_devices.len();
    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
    } else {
        mistralrs_quant::distributed::get_global_tp_size_from_devices()?
    };

    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
    if use_multi_node {
        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
    }

    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
        anyhow::bail!("Global world size {global_world_size} must be at least the local world size {local_world_size} and evenly divisible by it");
    }

    info!("Local tensor parallel world size is {local_world_size}");
    info!("Global tensor parallel world size is {global_world_size}");

    // TP uses parallel pipelines: each rank runs its own copy of the pipeline.
    let name = ipc_name()?;
    let mut id;
    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
        let WorkerTransferData::Init {
            id: new_id,
            worker_rank,
        } = payload;
        id = mistralrs_quant::Id::uninit(new_id.0);

        let mut stream = LocalStream::connect(name)?;
        stream.write_all(b"ready\n")?;
        worker_rank + 1
    } else if cfg!(feature = "ring") {
        id = mistralrs_quant::Id::new();

        let config = RingConfig::load();

        config.rank
    } else {
        id = mistralrs_quant::Id::new();
        let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
        let mut children = Vec::new();
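        // Re-exec the current binary once per worker device; each child inherits the
        // CLI arguments and receives the NCCL id plus its rank via IS_DAEMON_FLAG.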
        for worker_rank in 0..num_workers {
            let exe_path = env::current_exe().expect("Failed to get current exe");

            let args: Vec<String> = env::args().collect();

            let mut cmd = Command::new(exe_path);
            cmd.args(&args[1..]);

            let data = WorkerTransferData::Init {
                id: BigCCharArray(*id.internal()),
                worker_rank,
            };

            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);

            cmd.stdout(std::process::Stdio::null());
            cmd.stderr(std::process::Stdio::null());
            cmd.stdin(std::process::Stdio::null());

            children.push(cmd.spawn().expect("Failed to spawn process"));
        }

        let listener = ListenerOptions::new().name(name).create_sync()?;
        let mut ready_count = 0;

        while ready_count < num_workers {
            let stream = listener.accept()?;
            let mut reader = BufReader::new(stream);
            let mut message = String::new();
            reader.read_line(&mut message)?;
            if message.trim() == "ready" {
                ready_count += 1;
            }
        }
        info!("All workers have received the ids!");

        0
    };

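    // Multi-node id exchange: the head node broadcasts the NCCL unique id to every
    // worker node over TCP, and worker nodes block here until they receive it.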
    if use_multi_node {
        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
            info!("Head node managing {n_nodes} workers.");
            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
            };
            info!("Head node initializing connection on {port}.");
            let server = mistralrs_quant::Server::new(
                &format!("0.0.0.0:{port}"),
                n_nodes,
                local_world_size,
            )?;

            server.broadcast_id(&id)?;
        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
            info!("Worker node connecting to {addr}.");
            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

            id = client.receive_id()?;
        }
    }

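    // Each node owns a contiguous block of `local_world_size` global ranks: the head
    // node starts at 0 and worker node `n` starts at `(n + 1) * local_world_size`.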
    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
        };
        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
        info!("Worker ID is {node_id}.");
        (node_id + 1) * local_world_size
    } else {
        0
    };

    // Each rank blocks in communicator creation until every rank has joined:
    // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
    let comm = mistralrs_quant::Comm::from_device(
        id,
        device,
        local_rank + rank_offset,
        global_world_size,
    )?;

    let make_dummy_regexes = if loading_isq && from_uqff {
        // Dummy weights for the layers which will be overwritten...
        Some(std::sync::Arc::new(
            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
                model.isq_layer_regexes_moqe(config)?
            } else {
                model.isq_layer_regexes(config)?
            },
        ))
    } else {
        None
    };

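    // Memory-map all weight files on CPU; tensors are moved to the target device
    // below (or kept on CPU when ISQ will quantize them first).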
    let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
        paths.get_weight_filenames().to_vec(),
        vec![],
        Some(dtype),
        &Device::Cpu,
        vec![],
        silent,
        make_dummy_regexes,
        |_| true,
        Arc::new(|_| DeviceForLoadTensor::Base),
    )?;

    info!("Loading all ranks.");
    // The mapper is specific to this pipeline
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    }
    .into_mapper(model.num_layers(config)?, device, None)?;

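    // Without ISQ, tensors can live on the target device immediately; with ISQ, they
    // stay on the load device (CPU) until in-situ quantization places them.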
    let sharded_vb = if !loading_isq {
        sharded_vb.set_device(device.clone())
    } else {
        sharded_vb
    };

    Ok((mapper, sharded_vb))
}