mistralrs_core/distributed.rs

use anyhow::Context;
use candle_core::{DType, Device};
use core::ffi::c_char;
use interprocess::local_socket::traits::{Listener, Stream};
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
pub use mistralrs_quant::distributed::use_nccl;
use mistralrs_quant::{RingConfig, ShardedVarBuilder};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use std::env;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::mpsc::Sender;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};

pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";

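/// Returns `true` when this process is a spawned worker daemon rather than
/// the controlling (master) process: with NCCL/CUDA this is signalled by the
/// `IS_DAEMON_FLAG` environment variable, and with the ring backend by being
/// any non-master rank in the `RingConfig`.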
pub fn is_daemon() -> bool {
    if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
        std::env::var(IS_DAEMON_FLAG).is_ok()
    } else if cfg!(feature = "ring") {
        !RingConfig::load().is_master_rank()
    } else {
        false
    }
}

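/// Runs on daemon (worker) processes: a background thread repeatedly connects
/// to the master's IPC socket, reads one JSON-serialized `Request` per line,
/// and replays it into this process's engine via `request_sender` so every
/// rank executes the same work. Requests that carry a response channel get a
/// local stand-in channel, since the master alone answers the client.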
pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                let name = match ipc_name() {
                    Ok(name) => name,
                    Err(e) => {
                        tracing::error!("Failed to get IPC name in daemon: {e}");
                        continue;
                    }
                };
                if let Ok(stream) = LocalStream::connect(name) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    if let Err(e) = reader.read_line(&mut buf) {
                        tracing::error!("Failed to read line from IPC stream: {e}");
                        continue;
                    }
                    let mut req: Request = match serde_json::from_str(&buf) {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::error!("Failed to parse request JSON: {e}");
                            continue;
                        }
                    };

                    // Requests carrying a response channel cannot be forwarded
                    // as-is: swap in a local channel and drain the response
                    // here, since the master already answers the client.
                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Detokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Detokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Detokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Tokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Tokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Tokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Normal request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp.as_result() {
                                        tracing::error!("Normal response error: {e}");
                                    }
                                }
                                None => tracing::error!("Normal response channel closed"),
                            }
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    if request_sender.send(req).await.is_err() {
                        tracing::error!("Daemon channel closed for request");
                    }
                }
            }
        });
    });
}

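/// Ring-backend counterpart of `nccl_daemon_replicator`: instead of a local
/// IPC socket, non-master ranks connect to the master over TCP at the address
/// and port taken from `RingConfig`, then replay each received request
/// locally.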
pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
    let ring_config = RingConfig::load();

    let master_ip = ring_config.master_ip();
    let master_port = ring_config.master_port;
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    if let Err(e) = reader.read_line(&mut buf) {
                        tracing::error!("Failed to read line from TCP stream: {e}");
                        continue;
                    }
                    let mut req: Request = match serde_json::from_str(&buf) {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::error!("Failed to parse request JSON: {e}");
                            continue;
                        }
                    };

                    // As in the NCCL replicator, swap in a local response
                    // channel for request types that carry one.
                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Detokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Detokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Detokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Tokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Tokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Tokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Normal request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp.as_result() {
                                        tracing::error!("Normal response error: {e}");
                                    }
                                }
                                None => tracing::error!("Normal response channel closed"),
                            }
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    if request_sender.send(req).await.is_err() {
                        tracing::error!("Daemon channel closed for request");
                    }
                }
            }
        });
    });
}
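/// Wrapper that lets the 128-byte NCCL unique id derive `Serialize` and
/// `Deserialize`; serde's derive does not handle arrays this large out of the
/// box, hence `serde_big_array`.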
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);

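/// Payload handed to each spawned worker daemon through the `IS_DAEMON_FLAG`
/// environment variable: the shared NCCL id plus the worker's rank.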
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        id: BigCCharArray,
        worker_rank: usize,
    },
}

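/// Namespaced local-socket name shared by the master's listener and the
/// worker daemons, used for the ready-handshake and request replication.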
pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
    let printname = "mistralrs_daemon.sock";
    Ok(printname.to_ns_name::<GenericNamespaced>()?)
}

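/// Builds the NCCL (or ring) device mapper and the sharded var builder for
/// tensor-parallel loading: it determines the local and global world sizes,
/// establishes the NCCL id (rank 0 spawns the worker daemons; daemons read
/// the id from the environment), optionally exchanges the id across nodes,
/// and memory-maps the weights on the CPU for sharded loading.
///
/// A minimal multi-node sketch (values are illustrative, not defaults): with
/// 8 GPUs per node and 2 nodes, the head node would run with
/// `MISTRALRS_MN_GLOBAL_WORLD_SIZE=16`, `MISTRALRS_MN_HEAD_NUM_WORKERS=1`,
/// and `MISTRALRS_MN_HEAD_PORT=1234`, while the single worker node sets
/// `MISTRALRS_MN_GLOBAL_WORLD_SIZE=16`,
/// `MISTRALRS_MN_WORKER_SERVER_ADDR=<head-ip>:1234`, and
/// `MISTRALRS_MN_WORKER_ID=0`.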
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
    dtype: DType,
    device: &Device,
    available_devices: &[Device],
    silent: bool,
    config: &str,
    loading_isq: bool,
    from_uqff: bool,
    organization: IsqOrganization,
    model: &T,
    paths: &dyn ModelPaths,
) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
    if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
        tracing::warn!(
            "Distributed support was not included in the build, be sure to build with `--features nccl` or `--features ring`."
        );
    }

    // NCCL case!

    let local_world_size = available_devices.len();
    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
    } else {
        // global world size is always >= local world size
        std::cmp::max(
            mistralrs_quant::distributed::get_global_tp_size_from_devices()?,
            local_world_size,
        )
    };

    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
    if use_multi_node {
        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
    }

    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
        anyhow::bail!("Global world size {global_world_size} must be at least the local world size {local_world_size} and a multiple of it");
    }

    info!("Local tensor parallel world size is {local_world_size}");
    info!("Global tensor parallel world size is {global_world_size}");

    // TP uses parallel pipelines.
    let name = ipc_name()?;
    let mut id;
    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
        // Daemon (worker) process: the NCCL id and rank arrive through the
        // environment; signal readiness back to the master over IPC.
        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
        let WorkerTransferData::Init {
            id: new_id,
            worker_rank,
        } = payload;
        id = mistralrs_quant::Id::uninit(new_id.0);

        let mut stream = LocalStream::connect(name)?;
        stream.write_all(b"ready\n")?;
        worker_rank + 1
    } else if cfg!(feature = "ring") {
        id = mistralrs_quant::Id::new();

        let config = RingConfig::load();

        config.rank
    } else {
        // Master (rank 0): generate the NCCL id, spawn one daemon per extra
        // rank by re-executing the current binary with the id in its
        // environment, then wait for every worker's ready-handshake.
        id = mistralrs_quant::Id::new();
        let num_ranks = mistralrs_quant::distributed::get_global_tp_size_from_devices()?;
        let num_workers = num_ranks - 1;
        let mut children = Vec::new();
        for worker_rank in 0..num_workers {
            let exe_path = env::current_exe().expect("Failed to get current exe");

            let args: Vec<String> = env::args().collect();

            let mut cmd = Command::new(exe_path);
            cmd.args(&args[1..]);

            let data = WorkerTransferData::Init {
                id: BigCCharArray(*id.internal()),
                worker_rank,
            };

            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);

            cmd.stdout(std::process::Stdio::null());
            cmd.stderr(std::process::Stdio::null());
            cmd.stdin(std::process::Stdio::null());

            children.push(cmd.spawn().expect("Failed to spawn process"));
        }

        let listener = ListenerOptions::new().name(name).create_sync()?;
        let mut ready_count = 0;

        while ready_count < num_workers {
            let stream = listener.accept()?;
            let mut reader = BufReader::new(stream);
            let mut message = String::new();
            reader.read_line(&mut message)?;
            if message.trim() == "ready" {
                ready_count += 1;
            }
        }
        info!("All workers have received the ids!");

        0
    };

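    // Multi-node id exchange: the head node broadcasts the NCCL id it just
    // created to every worker node, and worker nodes replace their locally
    // generated id with the one received from the head.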
    if use_multi_node {
        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
            info!("Head node managing {n_nodes} workers.");
            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
            };
            info!("Head node initializing connection on {port}.");
            let server = mistralrs_quant::Server::new(
                &format!("0.0.0.0:{port}"),
                n_nodes,
                local_world_size,
            )?;

            server.broadcast_id(&id)?;
        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
            info!("Worker node connecting to {addr}.");
            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

            id = client.receive_id()?;
        }
    }

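    // Global rank = local_rank + rank_offset. Illustrative arithmetic: with
    // local_world_size = 8, the head node (offset 0) owns global ranks 0..=7,
    // and the worker node with MISTRALRS_MN_WORKER_ID=0 owns ranks 8..=15,
    // since its offset is (0 + 1) * 8 = 8.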
    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
        };
        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
        info!("Worker ID is {node_id}.");
        (node_id + 1) * local_world_size
    } else {
        0
    };

    // Initializing the NCCL communicator is collective: every rank blocks
    // here until all ranks have joined.
    // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
    let comm = mistralrs_quant::Comm::from_device(
        id,
        device,
        local_rank + rank_offset,
        global_world_size,
    )?;
    let make_dummy_regexes = if loading_isq && from_uqff {
        // Dummy weights for the layers which will be overwritten...
        Some(std::sync::Arc::new(
            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
                model.isq_layer_regexes_moqe(config)?
            } else {
                model.isq_layer_regexes(config)?
            },
        ))
    } else {
        None
    };

    let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
        paths.get_weight_filenames().to_vec(),
        vec![],
        Some(dtype),
        &Device::Cpu,
        vec![],
        silent,
        make_dummy_regexes,
        |_| true,
        Arc::new(|_| DeviceForLoadTensor::Base),
    )?;

    info!("Loading all ranks.");
    // The mapper is specific to this pipeline
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    }
    .into_mapper(model.num_layers(config)?, device, None)?;

    // When ISQ will be applied, keep the var builder on the CPU; otherwise
    // move it to the target device now.
    let sharded_vb = if !loading_isq {
        sharded_vb.set_device(device.clone())
    } else {
        sharded_vb
    };

    Ok((mapper, sharded_vb))
}