mistralrs_core/distributed.rs

use anyhow::Context;
use candle_core::{DType, Device};
use core::ffi::c_char;
use interprocess::local_socket::traits::{Listener, Stream};
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
pub use mistralrs_quant::distributed::use_nccl;
use mistralrs_quant::{ShardedSafeTensors, ShardedVarBuilder};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use std::env;
use std::io::{BufRead, BufReader, Write};
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use tracing::info;

use crate::device_map::DeviceMapper;
use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
use crate::{DeviceMapSetting, IsqOrganization, ModelPaths};

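/// Environment variable that marks a process as a spawned daemon worker; the parent sets it
/// to the serialized `WorkerTransferData` payload for that worker.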
pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL";

pub fn is_daemon() -> bool {
    std::env::var(IS_DAEMON_FLAG).is_ok()
}

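/// Raw communicator id bytes in a serde-friendly wrapper; a 128-element array needs the
/// `serde_big_array::BigArray` helper to derive (de)serialization.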
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);

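/// Message passed from the parent process to each spawned daemon via `IS_DAEMON_FLAG`: the
/// shared communicator id and the daemon's worker rank.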
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    Init {
        id: BigCCharArray,
        worker_rank: usize,
    },
}

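/// Namespaced local-socket name the daemon workers use to signal readiness back to the parent.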
pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
    let printname = "mistralrs_daemon.sock";
    Ok(printname.to_ns_name::<GenericNamespaced>()?)
}

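/// Sets up NCCL tensor parallelism for this process and returns the device mapper together
/// with a sharded var builder over the model weights.
///
/// The parent process generates the communicator id and spawns one daemon process per
/// additional local rank; daemons read the id from `IS_DAEMON_FLAG`. In multi-node mode the
/// id is additionally exchanged over TCP between the head node and the worker nodes.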
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
    dtype: DType,
    device: &Device,
    load_device: &Device,
    available_devices: &[Device],
    config: &str,
    loading_isq: bool,
    from_uqff: bool,
    organization: IsqOrganization,
    model: &T,
    paths: &dyn ModelPaths,
) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
    #[cfg(not(feature = "nccl"))]
    tracing::warn!(
        "NCCL support was not included in this build; build with `--features nccl` to enable it."
    );

    // NCCL case!

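    // Local world size is the number of devices mapped on this node; the global world size
    // comes from MISTRALRS_MN_GLOBAL_WORLD_SIZE when set (multi-node), otherwise from the
    // locally visible devices.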
    let local_world_size = available_devices.len();
    let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
        usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
    } else {
        mistralrs_quant::distributed::get_global_tp_size_from_devices()?
    };

    let use_multi_node = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE").is_ok();
    if use_multi_node {
        info!("MISTRALRS_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
    }

    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
        anyhow::bail!("Global world size {global_world_size} must be at least as large as, and divisible by, the local world size {local_world_size}");
    }

    info!("Local tensor parallel world size is {local_world_size}");
    info!("Global tensor parallel world size is {global_world_size}");

    // TP uses parallel pipelines.
    let name = ipc_name()?;
    let mut id;
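    // Daemon workers read the communicator id from IS_DAEMON_FLAG and report readiness over
    // the local socket; the parent generates a fresh id, spawns one daemon per remaining
    // local rank, and waits until every daemon has checked in.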
    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
        let WorkerTransferData::Init {
            id: new_id,
            worker_rank,
        } = payload;
        id = mistralrs_quant::Id::uninit(new_id.0);

        let mut stream = LocalStream::connect(name)?;
        stream.write_all(b"ready\n")?;
        worker_rank + 1
    } else {
        id = mistralrs_quant::Id::new();
        let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? - 1;
        let mut children = Vec::new();
        for worker_rank in 0..num_workers {
            let exe_path = env::current_exe().expect("Failed to get current exe");

            let args: Vec<String> = env::args().collect();

            let mut cmd = Command::new(exe_path);
            cmd.args(&args[1..]);

            let data = WorkerTransferData::Init {
                id: BigCCharArray(*id.internal()),
                worker_rank,
            };

            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);

            cmd.stdout(std::process::Stdio::null());
            cmd.stderr(std::process::Stdio::null());
            cmd.stdin(std::process::Stdio::null());

            children.push(cmd.spawn().expect("Failed to spawn process"));
        }

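        // Block until every spawned daemon has connected and sent its "ready" message.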
        let listener = ListenerOptions::new().name(name).create_sync()?;
        let mut ready_count = 0;

        while ready_count < num_workers {
            let stream = listener.accept()?;
            let mut reader = BufReader::new(stream);
            let mut message = String::new();
            reader.read_line(&mut message)?;
            if message.trim() == "ready" {
                ready_count += 1;
            }
        }
        info!("All workers have received the id!");

        0
    };

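    // Multi-node id exchange: the head node broadcasts the communicator id over TCP to each
    // worker node, while worker nodes replace their locally generated id with the one
    // received from the head.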
    if use_multi_node {
        if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
            let n_nodes = usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
            info!("Head node managing {n_nodes} workers.");
            let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
                anyhow::bail!("Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT");
            };
            info!("Head node initializing connection on {port}.");
            let server = mistralrs_quant::Server::new(
                &format!("0.0.0.0:{port}"),
                n_nodes,
                local_world_size,
            )?;

            server.broadcast_id(&id)?;
        } else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
            info!("Worker node connecting to {addr}.");
            let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

            id = client.receive_id()?;
        }
    }

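    // Worker nodes offset their local ranks by (node_id + 1) * local_world_size so that
    // global ranks are unique across nodes; the head node keeps ranks 0..local_world_size.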
    let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
        let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
            anyhow::bail!("Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID");
        };
        let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
        info!("Worker ID is {node_id}.");
        (node_id + 1) * local_world_size
    } else {
        0
    };

    // Communicator construction blocks until every rank has joined; see ncclCommInitRank:
    // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
    let comm = mistralrs_quant::Comm::from_device(
        id,
        device,
        local_rank + rank_offset,
        global_world_size,
    )?;

    let make_dummy_regexes = if loading_isq && from_uqff {
        // Dummy weights for the layers which will be overwritten...
        Some(std::sync::Arc::new(
            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
                model.isq_layer_regexes_moqe(config)?
            } else {
                model.isq_layer_regexes(config)?
            },
        ))
    } else {
        None
    };

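    // Map the safetensors shards into a var builder, substituting dummy tensors for layers
    // matched by `make_dummy_regexes` (they will be overwritten later). The constructor is
    // `unsafe`, presumably because the files are memory-mapped and must stay unmodified
    // while in use.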
    let sharded_vb = unsafe {
        ShardedSafeTensors::sharded(
            paths.get_weight_filenames(),
            dtype,
            load_device,
            make_dummy_regexes,
        )?
    };

    info!("Loading all ranks.");
    // The mapper is specific to this pipeline
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    }
    .into_mapper(model.num_layers(config)?, device, None)?;

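    // When not applying ISQ, move the var builder to the target device now; otherwise leave
    // it on the load device.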
    let sharded_vb = if !loading_isq {
        sharded_vb.clone().set_device(device.clone())
    } else {
        sharded_vb.clone()
    };

    Ok((mapper, sharded_vb))
}