mistralrs_core/device_map.rs

use std::{fmt::Debug, sync::Arc};

use crate::{
    pipeline::AutoDeviceMapParams,
    utils::{debug::DeviceRepr, log::once_log_info},
    MemoryUsage, Topology, TryIntoDType,
};
use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use tracing::info;

#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceLayerMapMetadata {
    pub ordinal: usize,
    pub layers: usize,
}

#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
    /// Manual device mapping.
    Map(DeviceMapMetadata),
    /// Automatic device mapping (recommended).
    Auto(AutoDeviceMapParams),
    /// Dummy device mapping for an NCCL pipeline.
    DummyNccl { nm_device: Device },
    /// Real device mapping for an NCCL pipeline.
    Nccl {
        nm_device: Device,
        comm: Arc<mistralrs_quant::Comm>,
    },
}

#[derive(Debug, Default, Deserialize, Clone)]
/// Metadata to initialize the device mapper.
pub struct DeviceMapMetadata {
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    host_layers: Option<usize>,
}

impl DeviceMapMetadata {
    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
        Self {
            device_layers: Some(device_layers),
            host_layers: None,
        }
    }
    /// Metadata for a device mapper that does not map any devices.
    pub fn dummy() -> Self {
        Self {
            device_layers: None,
            host_layers: None,
        }
    }
}
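
// A minimal usage sketch (illustrative; not part of the original file): building a
// manual mapping that places 16 repeating layers on GPU 0 and 16 on GPU 1, with any
// remaining layers defaulting to the host. The ordinals and layer counts are
// assumptions chosen only for this example.
//
// let metadata = DeviceMapMetadata::from_num_device_layers(vec![
//     DeviceLayerMapMetadata { ordinal: 0, layers: 16 },
//     DeviceLayerMapMetadata { ordinal: 1, layers: 16 },
// ]);
// let setting = DeviceMapSetting::Map(metadata);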

impl DeviceMapSetting {
    /// A device mapping setting that does not map any devices.
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using an NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using an NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                if let Some(topology) = topology {
                    if topology.0.iter().all(|x| x.is_none()) {
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        let layers = topology
                            .0
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

                // How many device layers there are, clamped to the number of model layers.
                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                // How many host (CPU) layers, defaulting to automatically filling the rest.
                // If n_device_layers > model_layers, then n_host_layers = 0.
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));

                // Handle multi-GPU mapping here.
                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_cuda_with_stream(*ordinal)?
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

                // Always put the CPU layers at the end so that we reduce dtoh and htod copies.
                combined.extend(vec![Device::Cpu; n_host_layers]);

                // Sanity check.
                assert_eq!(combined.len(), model_layers);

                // Print it out.
                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    // Iterate starting from index 1 to detect when the device variant changes.
                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        // If the variant changes, print the previous continuous block.
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i; // Start a new range.
                            current_dev = variant;
                        }
                    }

                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }

                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on an Auto device map; convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}
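
// Sketch of turning a setting into a concrete mapper (illustrative; the 32-layer
// count and the CUDA ordinal are assumptions for this example):
//
// let device = Device::new_cuda(0)?;
// let mapper: Box<dyn DeviceMapper + Send + Sync> =
//     setting.into_mapper(32, &device, None)?;
// // With no topology override, the 32 repeating layers are split across the
// // requested GPUs, and any remainder is placed on the CPU at the end.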

pub trait DeviceMapper: Debug {
    // === DURING RUNTIME ===
    /// Move a tensor to the device of the given layer during runtime.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    // === DURING LOADING TIME ===
    /// If loading an ISQ layer, then do not change the device. *This is done later in NormalModel::quantize.*
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    /// If loading an ISQ layer, then do not change the device (return None). *This is done later in NormalModel::quantize.*
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    fn get_unique_devices(&self) -> Vec<Device>;
    /// If loading an ISQ layer, then do not change the device. *This is done later in NormalModel::quantize.*
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    /// Set the non-mapped layer device. This is for ISQ + device mapping support.
    /// If loading an ISQ layer, then do not change the device. *This is done later in NormalModel::quantize.*
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    fn num_device_mapping_layers(&self) -> usize;
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

    // === IMMEDIATELY AFTER INIT ===
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}
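
// A rough sketch of how a pipeline typically drives this trait (the `vb`
// ShardedVarBuilder, the `loading_isq` flag, and the layer construction are
// placeholders, not code from this file):
//
// // Loading time: rebind the varbuilder to each layer's device, unless ISQ will
// // re-place the weights later in NormalModel::quantize.
// for layer in 0..mapper.num_device_mapping_layers() {
//     let layer_vb = mapper.set_device(layer, vb.clone(), loading_isq);
//     // ... build this layer's weights from `layer_vb` ...
// }
// // Runtime: move the hidden states to the device of the layer about to run.
// let hidden = mapper.map(hidden, layer_idx)?;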

#[derive(Debug)]
/// A device mapper which does device mapping per hidden layer.
pub struct LayerDeviceMapper {
    mappings: Vec<Device>,
    nm_device: Device,
}

impl DeviceMapper for LayerDeviceMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer])
    }
    fn set_device<'a>(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
            if !acc.iter().any(|d| d.same_device(device)) {
                acc.push(device.clone());
            }
            acc
        })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

#[derive(Debug)]
pub struct DummyDeviceMapper {
    nm_device: Device,
}

impl DeviceMapper for DummyDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device<'a>(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        // Effectively one layer
        1
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    model_layers: usize,
    comm: Option<Arc<mistralrs_quant::Comm>>,
}

impl DeviceMapper for NcclDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device<'a>(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.model_layers
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        if let Some(comm) = &self.comm {
            Ok(comm.clone())
        } else {
            let id = mistralrs_quant::Id::new();
            Ok(Arc::new(mistralrs_quant::Comm::from_device(
                id,
                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
                0,
                1,
            )?))
        }
    }
}

#[derive(Debug)]
/// A device mapper which does device mapping per hidden layer, with an NCCL communicator per layer.
pub struct NcclPipelineParallelMapper {
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    nm_device: Device,
}

impl DeviceMapper for NcclPipelineParallelMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer].1)
    }
    fn set_device<'a>(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].1.clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer).map(|(_, x)| x)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings
            .iter()
            .fold(Vec::new(), |mut acc, (_, device)| {
                if !acc.iter().any(|d| d.same_device(device)) {
                    acc.push(device.clone());
                }
                acc
            })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        Ok(self.mappings[layer_idx].0.clone())
    }
}

/// Get all available devices of the same device type as `base` (including `base` itself), across all ordinals.
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                // Needs to be without a stream as PagedAttention doesn't like it otherwise.
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            let total_ords = metal::Device::all().len();
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}
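
// Usage sketch (illustrative): enumerate every visible device of the same backend as
// a base device, e.g. when spreading layers across all GPUs of one kind.
//
// let base = Device::new_cuda(0)?; // or Device::Cpu, or Device::new_metal(0)?
// let devices = get_all_similar_devices(&base)?;
// // On a machine with 4 CUDA GPUs this yields devices with ordinals 0..=3;
// // for `Device::Cpu` it yields just `[Device::Cpu]`.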