mistralrs_core/distributed.rs

use anyhow::Context;
use candle_core::{DType, Device};
use core::ffi::c_char;
use interprocess::local_socket::traits::{Listener, Stream};
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
pub use mistralrs_quant::distributed::use_nccl;
use mistralrs_quant::{RingConfig, ShardedVarBuilder};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use std::env;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::mpsc::Sender;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};

pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";

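/// Returns true if this process is a distributed worker (daemon) rather than the
/// master: NCCL workers are launched with `IS_DAEMON_FLAG` set in the environment,
/// while with the ring backend every non-master rank is a daemon.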
pub fn is_daemon() -> bool {
    if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
        std::env::var(IS_DAEMON_FLAG).is_ok()
    } else if cfg!(feature = "ring") {
        !RingConfig::load().is_master_rank()
    } else {
        false
    }
}

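/// Spawn a background thread on an NCCL daemon that connects to the master's IPC
/// socket, deserializes each replicated request, and replays it on the local
/// engine so this rank performs the same collective work as the master.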
pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                let name = match ipc_name() {
                    Ok(name) => name,
                    Err(e) => {
                        tracing::error!("Failed to get IPC name in daemon: {e}");
                        continue;
                    }
                };
                if let Ok(stream) = LocalStream::connect(name) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    if let Err(e) = reader.read_line(&mut buf) {
                        tracing::error!("Failed to read line from IPC stream: {e}");
                        continue;
                    }
                    let mut req: Request = match serde_json::from_str(&buf) {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::error!("Failed to parse request JSON: {e}");
                            continue;
                        }
                    };

                    // Rewire each request's response channel to a local receiver:
                    // this rank must execute the same work as the master, but only
                    // the master replies to clients, so responses are drained here.
                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Detokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Detokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Detokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Tokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Tokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Tokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Normal request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp.as_result() {
                                        tracing::error!("Normal response error: {e}");
                                    }
                                }
                                None => tracing::error!("Normal response channel closed"),
                            }
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    if request_sender.send(req).await.is_err() {
                        tracing::error!("Daemon channel closed for request");
                    }
                }
            }
        });
    });
}

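/// Ring-backend analogue of [`nccl_daemon_replicator`]: connects to the master
/// rank over TCP instead of a local socket and replays each request locally.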
pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
    let ring_config = RingConfig::load();

    let master_ip = ring_config.master_ip();
    let master_port = ring_config.master_port;
    std::thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            loop {
                if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
                    let mut reader = BufReader::new(stream);
                    let mut buf = String::new();
                    if let Err(e) = reader.read_line(&mut buf) {
                        tracing::error!("Failed to read line from TCP stream: {e}");
                        continue;
                    }
                    let mut req: Request = match serde_json::from_str(&buf) {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::error!("Failed to parse request JSON: {e}");
                            continue;
                        }
                    };

                    // As in the NCCL path, responses are drained locally; only the
                    // master rank replies to clients.
                    req = match req {
                        Request::ReIsq(x) => Request::ReIsq(x),
                        Request::Terminate => Request::Terminate,
                        Request::Detokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Detokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Detokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Detokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Detokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Tokenize(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.response = sender;
                            let req = Request::Tokenize(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Tokenize request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp {
                                        tracing::error!("Tokenize response error: {e}");
                                    }
                                }
                                None => tracing::error!("Tokenize response channel closed"),
                            }
                            continue;
                        }
                        Request::Normal(mut x) => {
                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
                            x.is_streaming = false;
                            x.response = sender;
                            let req = Request::Normal(x);

                            if request_sender.send(req).await.is_err() {
                                tracing::error!("Daemon channel closed for Normal request");
                                continue;
                            }
                            match receiver.recv().await {
                                Some(resp) => {
                                    if let Err(e) = resp.as_result() {
                                        tracing::error!("Normal response error: {e}");
                                    }
                                }
                                None => tracing::error!("Normal response channel closed"),
                            }
                            continue;
                        }
                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
                    };

                    if request_sender.send(req).await.is_err() {
                        tracing::error!("Daemon channel closed for request");
                    }
                }
            }
        });
    });
}

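// The NCCL unique id is a raw 128-byte `c_char` array; serde's built-in array
// impls stop at 32 elements, so `serde_big_array::BigArray` handles (de)serialization.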
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);

#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        id: BigCCharArray,
        worker_rank: usize,
    },
}

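/// Namespaced local-socket name shared by the master and the NCCL daemons, used
/// for both the startup handshake and request replication.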
pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
    let printname = "mistralrs_daemon.sock";
    Ok(printname.to_ns_name::<GenericNamespaced>()?)
}

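/// Prepare tensor-parallel loading: establish the NCCL/ring communicator (on the
/// NCCL head, this spawns one daemon process per additional device), compute this
/// process's global rank, and return an NCCL device mapper together with a
/// sharded varbuilder over the model's safetensors weights.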
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
    dtype: DType,
    device: &Device,
    available_devices: &[Device],
    silent: bool,
    config: &str,
    loading_isq: bool,
    from_uqff: bool,
    organization: IsqOrganization,
    model: &T,
    paths: &dyn ModelPaths,
) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
    if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
        tracing::warn!(
            "Distributed support was not included in the build; be sure to build with `--features nccl` or `--features ring`."
        );
    }

    // NCCL case!

    let local_world_size = available_devices.len();
    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
    } else {
        mistralrs_quant::distributed::get_global_tp_size_from_devices()?
    };

    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
    if use_multi_node {
        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
    }

    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
        anyhow::bail!("Global world size {global_world_size} must be at least the local world size {local_world_size} and divisible by it");
    }

    info!("Local tensor parallel world size is {local_world_size}");
    info!("Global tensor parallel world size is {global_world_size}");

    // TP uses parallel pipelines.
    let name = ipc_name()?;
    let mut id;
    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
        // Daemon process: adopt the id sent by the head, then signal readiness.
        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
        let WorkerTransferData::Init {
            id: new_id,
            worker_rank,
        } = payload;
        id = mistralrs_quant::Id::uninit(new_id.0);

        let mut stream = LocalStream::connect(name)?;
        stream.write_all(b"ready\n")?;
        worker_rank + 1
    } else if cfg!(feature = "ring") {
        // Ring backend: the rank comes from the ring config; nothing is spawned.
        id = mistralrs_quant::Id::new();

        let config = RingConfig::load();

        config.rank
    } else {
        // Head process: spawn one daemon per remaining device, re-running this
        // executable with the NCCL unique id passed through the env flag.
        id = mistralrs_quant::Id::new();
        let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
        let mut children = Vec::new();
        for worker_rank in 0..num_workers {
            let exe_path = env::current_exe().expect("Failed to get current exe");

            let args: Vec<String> = env::args().collect();

            let mut cmd = Command::new(exe_path);
            cmd.args(&args[1..]);

            let data = WorkerTransferData::Init {
                id: BigCCharArray(*id.internal()),
                worker_rank,
            };

            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);

            cmd.stdout(std::process::Stdio::null());
            cmd.stderr(std::process::Stdio::null());
            cmd.stdin(std::process::Stdio::null());

            children.push(cmd.spawn().expect("Failed to spawn process"));
        }

        // Block until every worker has connected and acknowledged the id.
        let listener = ListenerOptions::new().name(name).create_sync()?;
        let mut ready_count = 0;

        while ready_count < num_workers {
            let stream = listener.accept()?;
            let mut reader = BufReader::new(stream);
            let mut message = String::new();
            reader.read_line(&mut message)?;
            if message.trim() == "ready" {
                ready_count += 1;
            }
        }
        info!("All workers have received the id!");

        0
    };

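    // Multi-node only: exchange the NCCL unique id so every node builds its
    // communicator from the same id. The head serves it; worker nodes receive it.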
    if use_multi_node {
        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
            info!("Head node managing {n_nodes} workers.");
            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
            };
            info!("Head node initializing connection on {port}.");
            let server = mistralrs_quant::Server::new(
                &format!("0.0.0.0:{port}"),
                n_nodes,
                local_world_size,
            )?;

            server.broadcast_id(&id)?;
        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
            info!("Worker node connecting to {addr}.");
            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

            id = client.receive_id()?;
        }
    }

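    // Global rank = local rank + offset: the head node owns ranks
    // 0..local_world_size, and worker node n owns the block starting at
    // (n + 1) * local_world_size.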
    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
        };
        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
        info!("Worker ID is {node_id}.");
        (node_id + 1) * local_world_size
    } else {
        0
    };

    // All ranks block on each other during communicator creation; see ncclCommInitRank:
    // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
    let comm = mistralrs_quant::Comm::from_device(
        id,
        device,
        local_rank + rank_offset,
        global_world_size,
    )?;

    let make_dummy_regexes = if loading_isq && from_uqff {
        // Dummy weights for the layers which will be overwritten...
        Some(std::sync::Arc::new(
            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
                model.isq_layer_regexes_moqe(config)?
            } else {
                model.isq_layer_regexes(config)?
            },
        ))
    } else {
        None
    };

    let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
        paths.get_weight_filenames().to_vec(),
        vec![],
        Some(dtype),
        &Device::Cpu,
        vec![],
        silent,
        make_dummy_regexes,
        |_| true,
        Arc::new(|_| DeviceForLoadTensor::Base),
    )?;

    info!("Loading all ranks.");
    // The mapper is specific to this pipeline
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    }
    .into_mapper(model.num_layers(config)?, device, None)?;

    let sharded_vb = if !loading_isq {
        sharded_vb.clone().set_device(device.clone())
    } else {
        sharded_vb.clone()
    };

    Ok((mapper, sharded_vb))
}