// mistralrs_core/device_map.rs

use std::{fmt::Debug, sync::Arc};

use crate::{
    pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
};
use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use tracing::info;

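/// Assigns a number of repeating layers to the device with the given ordinal.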
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceLayerMapMetadata {
    pub ordinal: usize,
    pub layers: usize,
}

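/// Controls how the model's repeating layers are distributed across devices.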
#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
    /// Manual device mapping.
    Map(DeviceMapMetadata),
    /// Automatic device mapping (recommended).
    Auto(AutoDeviceMapParams),
    /// Dummy device mapping for an NCCL pipeline.
    DummyNccl { nm_device: Device },
    /// Real device mapping for an NCCL pipeline.
    Nccl {
        nm_device: Device,
        comm: Arc<mistralrs_quant::Comm>,
    },
}

#[derive(Debug, Default, Deserialize, Clone)]
/// Metadata to initialize the device mapper.
pub struct DeviceMapMetadata {
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    host_layers: Option<usize>,
}

impl DeviceMapMetadata {
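    /// Build a manual mapping from explicit per-device layer counts.
    ///
    /// A minimal sketch of constructing a manual setting; the ordinals and layer
    /// counts below are hypothetical, and the example assumes these types are in scope:
    ///
    /// ```ignore
    /// // 16 layers on GPU ordinal 0, 16 on ordinal 1; any remainder defaults to the host.
    /// let setting = DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
    ///     DeviceLayerMapMetadata { ordinal: 0, layers: 16 },
    ///     DeviceLayerMapMetadata { ordinal: 1, layers: 16 },
    /// ]));
    /// ```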
    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
        Self {
            device_layers: Some(device_layers),
            host_layers: None,
        }
    }
    /// A mapping that performs no device mapping: everything stays on the non-mapped device.
    pub fn dummy() -> Self {
        Self {
            device_layers: None,
            host_layers: None,
        }
    }
}

impl DeviceMapSetting {
    /// A setting that performs no device mapping: everything stays on the non-mapped device.
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }
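    /// Build the concrete [`DeviceMapper`] for a model with `model_layers`
    /// repeating layers. `device` is the non-mapped (primary) device; a
    /// [`Topology`], if provided, overrides per-layer placement.
    ///
    /// A minimal sketch (hypothetical layer count; assumes a CUDA device is available):
    ///
    /// ```ignore
    /// let mapper = setting.into_mapper(32, &Device::new_cuda(0)?, None)?;
    /// ```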
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                if let Some(topology) = topology {
                    if topology.0.iter().all(|x| x.is_none()) {
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        let layers = topology
                            .0
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

                // Determine how many repeating layers live on devices,
                // clamped to at most the number of model layers.
                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                // How many host (CPU) layers, defaulting to automatically filling the rest.
                // If n_device_layers > model_layers, n_host_layers = 0.
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));

                // Handle multi-GPU mapping here
                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    // Capture the primary device's RNG seed so it can be propagated
                    // to every newly created device below.
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_cuda_with_stream(*ordinal)?
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

                // Always put the CPU layers at the end so that we reduce device-to-host and host-to-device copies
                combined.extend(vec![Device::Cpu; n_host_layers]);

                // Sanity check
                assert_eq!(combined.len(), model_layers);

                // Print out the final mapping, collapsing contiguous runs on the same device
                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    // Iterate starting from index 1 to detect when the device changes
                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        // If the device changes, print the previous contiguous block
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i; // start a new range
                            current_dev = variant;
                        }
                    }

                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }

                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}

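/// Places model layers on devices: moves activations between devices at runtime
/// and chooses the device for each layer's weights at load time.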
pub trait DeviceMapper: Debug {
    // === DURING RUNTIME ===
    /// Map the input tensor to the device for this layer at runtime.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    // === DURING LOADING TIME ===
    /// If loading an ISQ layer, do not change the device. *It will be moved later in NormalModel::quantize*
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    /// If loading an ISQ layer, do not change the device (return None). *It will be moved later in NormalModel::quantize*
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    fn get_unique_devices(&self) -> Vec<Device>;
    /// If loading an ISQ layer, do not change the device (return None). *It will be moved later in NormalModel::quantize*
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    /// Set the non-mapped layer device. This is for ISQ + device mapping support.
    /// If loading an ISQ layer, do not change the device. *It will be moved later in NormalModel::quantize*
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    fn num_device_mapping_layers(&self) -> usize;
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

    // === IMMEDIATELY AFTER INIT ===
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}

#[derive(Debug)]
/// A device mapper which does device mapping per hidden layer.
pub struct LayerDeviceMapper {
    mappings: Vec<Device>,
    nm_device: Device,
}

impl DeviceMapper for LayerDeviceMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer])
    }
    fn set_device<'a>(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
            if !acc.iter().any(|d| d.same_device(device)) {
                acc.push(device.clone());
            }
            acc
        })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        // Per-layer device mapping does not use a shared communicator; create one
        // scoped to this layer's device.
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

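/// A device mapper which performs no per-layer mapping: everything stays on the
/// single non-mapped device.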
#[derive(Debug)]
pub struct DummyDeviceMapper {
    nm_device: Device,
}

impl DeviceMapper for DummyDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device<'a>(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        // Effectively one layer
        1
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

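/// A device mapper for NCCL tensor parallelism: all layers stay on the local
/// device, and `get_comm_for` returns the shared communicator when one was
/// provided (otherwise a fresh one is created).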
#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    model_layers: usize,
    comm: Option<Arc<mistralrs_quant::Comm>>,
}

impl DeviceMapper for NcclDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device<'a>(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.model_layers
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        if let Some(comm) = &self.comm {
            Ok(comm.clone())
        } else {
            let id = mistralrs_quant::Id::new();
            Ok(Arc::new(mistralrs_quant::Comm::from_device(
                id,
                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
                0,
                1,
            )?))
        }
    }
}

#[derive(Debug)]
/// A device mapper for NCCL pipeline parallelism: each hidden layer maps to its
/// own (communicator, device) pair.
pub struct NcclPipelineParallelMapper {
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    nm_device: Device,
}

impl DeviceMapper for NcclPipelineParallelMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer].1)
    }
    fn set_device<'a>(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].1.clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer).map(|(_, x)| x)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings
            .iter()
            .fold(Vec::new(), |mut acc, (_, device)| {
                if !acc.iter().any(|d| d.same_device(device)) {
                    acc.push(device.clone());
                }
                acc
            })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        Ok(self.mappings[layer_idx].0.clone())
    }
}

/// Get all available devices of the same type as `base` (including `base` itself), one per ordinal.
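///
/// A minimal sketch, assuming CUDA ordinals 0 and 1 exist:
///
/// ```ignore
/// let base = Device::new_cuda(0)?;
/// // Returns the devices in ordinal order, starting from ordinal 0.
/// let all = get_all_similar_devices(&base)?;
/// ```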
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                // Needs to be without a stream as PagedAttention doesn't like it otherwise.
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            let total_ords = metal::Device::all().len();
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}