mistralrs_core/pipeline/loaders/
auto_device_map.rs

1use std::fmt::{self, Display};
2
3use crate::paged_attention::{
4    calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
5};
6use crate::utils::debug::DeviceRepr;
7use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
8use anyhow::{Context, Result};
9use candle_core::{DType, Device};
10use itertools::Itertools;
11use tracing::{info, warn};
12
13use super::DeviceMappedModelLoader;
14
/// Sub-models that are excluded from automatic device mapping and are instead
/// loaded whole onto a single device (see `get_device_layers`, which logs them
/// as loaded on the last fallback device).
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    /// The vision component of a multimodal model.
    Vision,
}
19
20impl Display for NonMappedSubModel {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        match self {
23            NonMappedSubModel::Vision => write!(f, "vision"),
24        }
25    }
26}
27
/// Parameters bounding the worst-case activation and KV-cache memory that the
/// automatic device mapper must plan for.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    /// Parameters for a text-only model.
    Text {
        /// Maximum sequence length to reserve memory for.
        max_seq_len: usize,
        /// Maximum concurrent batch size to reserve memory for.
        max_batch_size: usize,
    },
    /// Parameters for a vision (multimodal) model.
    Vision {
        /// Maximum sequence length to reserve memory for.
        max_seq_len: usize,
        /// Maximum concurrent batch size to reserve memory for.
        max_batch_size: usize,
        /// Maximum image shape as (height, width) — TODO confirm axis order
        /// against callers; only used here as an opaque pair.
        max_image_shape: (usize, usize),
        /// Maximum number of images per request to reserve memory for.
        max_num_images: usize,
    },
}
41
42impl AutoDeviceMapParams {
43    pub fn maybe_promote_to_vision(&self) -> Self {
44        match *self {
45            Self::Text {
46                max_seq_len,
47                max_batch_size,
48            } => Self::Vision {
49                max_seq_len,
50                max_batch_size,
51                max_image_shape: (
52                    Self::DEFAULT_MAX_IMAGE_LENGTH,
53                    Self::DEFAULT_MAX_IMAGE_LENGTH,
54                ),
55                max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
56            },
57            Self::Vision {
58                max_seq_len,
59                max_batch_size,
60                max_image_shape,
61                max_num_images,
62            } => Self::Vision {
63                max_seq_len,
64                max_batch_size,
65                max_image_shape,
66                max_num_images,
67            },
68        }
69    }
70
71    pub fn max_seq_len(&self) -> usize {
72        match self {
73            Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
74        }
75    }
76}
77
78impl Display for AutoDeviceMapParams {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            Self::Text {
82                max_seq_len,
83                max_batch_size,
84            } => write!(
85                f,
86                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
87            ),
88            Self::Vision {
89                max_seq_len,
90                max_batch_size,
91                max_image_shape,
92                max_num_images,
93            } => write!(
94                f,
95                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
96            ),
97        }
98    }
99}
100
101impl AutoDeviceMapParams {
102    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
103    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
104    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
105    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;
106
107    pub fn default_text() -> Self {
108        Self::Text {
109            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
110            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
111        }
112    }
113
114    pub fn default_vision() -> Self {
115        Self::Vision {
116            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
117            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
118            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
119            max_image_shape: (
120                Self::DEFAULT_MAX_IMAGE_LENGTH,
121                Self::DEFAULT_MAX_IMAGE_LENGTH,
122            ),
123        }
124    }
125}
126
127fn calculate_key_block_shape(
128    model_config: &dyn ModelConfigLike,
129    dtype: DType,
130    block_size: usize,
131) -> (usize, usize, usize, usize) {
132    let element_size = dtype.size_in_bytes();
133    let x = 16 / element_size;
134    (
135        model_config.num_kv_heads(),
136        model_config.k_head_dim() / x,
137        block_size,
138        x,
139    )
140}
141
142fn calculate_value_block_shape(
143    model_config: &dyn ModelConfigLike,
144    block_size: usize,
145) -> (usize, usize, usize) {
146    (
147        model_config.num_kv_heads(),
148        model_config.v_head_dim(),
149        block_size,
150    )
151}
152
/// Converts a byte count to whole mebibytes (integer division); used for
/// human-readable log and error messages.
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
158
159#[allow(clippy::too_many_arguments)]
160/// Core logic for automatic device mapping
161pub fn get_device_layers(
162    loader: &dyn DeviceMappedModelLoader,
163    config: &str,
164    num_layers: usize,
165    mut layer_sizes_in_bytes: Vec<usize>,
166    non_mapped_size_in_bytes: usize,
167    total_model_size_in_bytes: usize,
168    devices: &[Device],
169    dtype: DType,
170    params: &AutoDeviceMapParams,
171    prompt_chunksize: usize,
172    paged_attn_config: Option<&PagedAttentionConfig>,
173) -> Result<DeviceMapMetadata> {
174    let mapped_max =
175        loader.mapped_max_act_size_elems(config, params, prompt_chunksize)? * dtype.size_in_bytes();
176    let non_mapped_max =
177        loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
178
179    let mut remaining = total_model_size_in_bytes;
180    let max_seq_len = match params {
181        AutoDeviceMapParams::Text { max_seq_len, .. }
182        | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
183    };
184    let max_batch_size = match params {
185        AutoDeviceMapParams::Text { max_batch_size, .. }
186        | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
187    };
188
189    let model_cfg = loader.model_config(config)?;
190    let kv_cache_elems = match paged_attn_config {
191        Some(cfg) => {
192            let cache = calculate_cache_config(
193                cfg.mem_gpu,
194                cfg.mem_cpu,
195                Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
196                dtype,
197                &*model_cfg,
198                &devices[0],
199                &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
200                true,
201            )?;
202            let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
203            let key_sz =
204                cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
205            let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
206            let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
207            key_sz + val_sz
208        }
209        None => {
210            let key_shape = [
211                max_batch_size,
212                model_cfg.num_kv_heads(),
213                max_seq_len,
214                model_cfg.k_head_dim(),
215            ];
216            let val_shape = [
217                max_batch_size,
218                model_cfg.num_kv_heads(),
219                max_seq_len,
220                model_cfg.v_head_dim(),
221            ];
222            key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
223        }
224    };
225    let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
226
227    // prepare available memory per device, CPU fallback last
228    let mut avail = Vec::new();
229    for dev in [devices, &[Device::Cpu]].concat() {
230        let a = MemoryUsage.get_memory_available(&dev)?;
231        avail.push((a, dev));
232    }
233    avail.reverse();
234    layer_sizes_in_bytes.reverse();
235
236    let mut mappings = Vec::new();
237    info!("Using automatic device mapping parameters: {params}.");
238    if let Some(subs) = loader.non_mapped_sub_models() {
239        let (_, last) = avail.last().unwrap();
240        info!(
241            "The following sub-models will not be device mapped and will be loaded on {}: {}",
242            last.device_pretty_repr(),
243            subs.iter().map(|x| x.to_string()).join(", ")
244        );
245    }
246
247    let mut ordinal = 0;
248    let mut layer = 0;
249    let avail_copy = avail.clone();
250    let mut includes_cpu = false;
251    while remaining > 0 && !avail.is_empty() {
252        let (cap, dev) = avail
253            .pop()
254            .context("No more devices to map to. The model does not fit on this system.")?;
255
256        // All usage of 90% of the memory as a maximum.
257        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
258        let cap = (cap as f64 * 0.90) as usize;
259
260        // Algorithm is to check the following:
261        // 1) (no mapping) if *everything* fits on the first dev (non mapped and mapped)
262        // 2) if the mapped activations plus remaining fits on the nth device
263        // 3) common case, iteratively find the optimal amount of layers to put on the nth device
264        //   - if this is the first dev: must hold the non-mapped act and non-mapped model
265        //   - otherwise, must hold the mapped act
266        let required_whole_capacity = if ordinal == 0 {
267            remaining
268                + non_mapped_max.max(mapped_max)
269                + non_mapped_size_in_bytes
270                + kv_cache_bytes * (num_layers - layer)
271        } else {
272            remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
273        };
274
275        let layers_on_dev = if cap >= required_whole_capacity {
276            remaining = 0;
277            num_layers - layer
278        } else {
279            let mut used = mapped_max;
280            let mut used_no_act = 0;
281            let mut count = 0;
282            if ordinal == 0 {
283                used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
284                used_no_act += non_mapped_size_in_bytes;
285            }
286            while let Some(&sz) = layer_sizes_in_bytes.last() {
287                let delta = sz + kv_cache_bytes;
288                if used + delta > cap {
289                    break;
290                }
291                layer_sizes_in_bytes.pop();
292                used += delta;
293                used_no_act += delta;
294                count += 1;
295            }
296            if count > 0 {
297                remaining = remaining.saturating_sub(used_no_act);
298            } else {
299                warn!(
300                    "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
301                    dev.device_pretty_repr(),
302                );
303                ordinal += 1;
304                continue;
305            }
306            count
307        };
308        if !dev.is_cpu() {
309            mappings.push(DeviceLayerMapMetadata {
310                ordinal,
311                layers: layers_on_dev,
312            });
313            ordinal += 1;
314        } else {
315            includes_cpu = true;
316        }
317        layer += layers_on_dev;
318    }
319    if remaining > 0 {
320        let over = b_to_mb!(remaining);
321        anyhow::bail!(
322            "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
323            avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
324            over
325        );
326    }
327    if paged_attn_config.is_some_and(|_| includes_cpu) {
328        return get_device_layers(
329            loader,
330            config,
331            num_layers,
332            layer_sizes_in_bytes,
333            non_mapped_size_in_bytes,
334            total_model_size_in_bytes,
335            devices,
336            dtype,
337            params,
338            prompt_chunksize,
339            None,
340        );
341    }
342    Ok(DeviceMapMetadata::from_num_device_layers(mappings))
343}