mistralrs_core/pipeline/loaders/
auto_device_map.rs

1use std::fmt::{self, Display};
2
3use crate::paged_attention::{
4    calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
5};
6use crate::utils::debug::DeviceRepr;
7use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
8use anyhow::{Context, Result};
9use candle_core::{DType, Device};
10use itertools::Itertools;
11use tracing::{info, warn};
12
13use super::DeviceMappedModelLoader;
14
// Fraction of a GPU's reported free memory held back from auto-mapping as an OOM safety margin.
const GPU_RESERVE_FRACTION: f64 = 0.02;
// Lower bound on the per-GPU reserve, applied when the fractional reserve would be smaller.
const GPU_MIN_RESERVE_BYTES: usize = 512 * 1024 * 1024; // 512MB safety buffer
17
/// Sub-models (e.g. the vision or audio towers of a multimodal model) that are
/// never split across devices by the automatic device mapper; they are loaded
/// whole on a single device.
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    Vision,
    Audio,
}
23
24impl Display for NonMappedSubModel {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        match self {
27            NonMappedSubModel::Vision => write!(f, "vision"),
28            NonMappedSubModel::Audio => write!(f, "audio"),
29        }
30    }
31}
32
/// Workload bounds used by the automatic device mapper to estimate peak
/// activation and KV-cache memory requirements.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    /// Bounds for text-only models.
    Text {
        // Longest sequence (tokens) to budget memory for.
        max_seq_len: usize,
        // Largest concurrent batch to budget memory for.
        max_batch_size: usize,
    },
    /// Bounds for vision (multimodal) models.
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        // Upper bound on input image dimensions — presumably (height, width); TODO confirm with callers.
        max_image_shape: (usize, usize),
        // Maximum number of images per request to budget memory for.
        max_num_images: usize,
    },
}
46
47impl AutoDeviceMapParams {
48    pub fn maybe_promote_to_vision(&self) -> Self {
49        match *self {
50            Self::Text {
51                max_seq_len,
52                max_batch_size,
53            } => Self::Vision {
54                max_seq_len,
55                max_batch_size,
56                max_image_shape: (
57                    Self::DEFAULT_MAX_IMAGE_LENGTH,
58                    Self::DEFAULT_MAX_IMAGE_LENGTH,
59                ),
60                max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
61            },
62            Self::Vision {
63                max_seq_len,
64                max_batch_size,
65                max_image_shape,
66                max_num_images,
67            } => Self::Vision {
68                max_seq_len,
69                max_batch_size,
70                max_image_shape,
71                max_num_images,
72            },
73        }
74    }
75
76    pub fn max_seq_len(&self) -> usize {
77        match self {
78            Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
79        }
80    }
81}
82
83impl Display for AutoDeviceMapParams {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            Self::Text {
87                max_seq_len,
88                max_batch_size,
89            } => write!(
90                f,
91                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
92            ),
93            Self::Vision {
94                max_seq_len,
95                max_batch_size,
96                max_image_shape,
97                max_num_images,
98            } => write!(
99                f,
100                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
101            ),
102        }
103    }
104}
105
106impl AutoDeviceMapParams {
107    // Default max sequence length for memory estimation when not specified
108    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
109    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
110    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
111    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;
112
113    pub fn default_text() -> Self {
114        Self::Text {
115            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
116            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
117        }
118    }
119
120    pub fn default_vision() -> Self {
121        Self::Vision {
122            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
123            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
124            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
125            max_image_shape: (
126                Self::DEFAULT_MAX_IMAGE_LENGTH,
127                Self::DEFAULT_MAX_IMAGE_LENGTH,
128            ),
129        }
130    }
131}
132
133fn calculate_key_block_shape(
134    model_config: &dyn ModelConfigLike,
135    dtype: DType,
136    block_size: usize,
137) -> (usize, usize, usize, usize) {
138    let element_size = dtype.size_in_bytes();
139    let x = 16 / element_size;
140    (
141        model_config.num_kv_heads(),
142        model_config.k_head_dim() / x,
143        block_size,
144        x,
145    )
146}
147
148fn calculate_value_block_shape(
149    model_config: &dyn ModelConfigLike,
150    block_size: usize,
151) -> (usize, usize, usize) {
152    (
153        model_config.num_kv_heads(),
154        model_config.v_head_dim(),
155        block_size,
156    )
157}
158
/// Convert a byte count to whole mebibytes (integer division, truncating).
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
164
165#[allow(clippy::too_many_arguments)]
166/// Core logic for automatic device mapping
167pub fn get_device_layers(
168    loader: &dyn DeviceMappedModelLoader,
169    config: &str,
170    num_layers: usize,
171    mut layer_sizes_in_bytes: Vec<usize>,
172    non_mapped_size_in_bytes: usize,
173    total_model_size_in_bytes: usize,
174    devices: &[Device],
175    dtype: DType,
176    params: &AutoDeviceMapParams,
177    paged_attn_config: Option<&PagedAttentionConfig>,
178) -> Result<DeviceMapMetadata> {
179    let mapped_max = loader.mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
180    let non_mapped_max =
181        loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
182
183    let mut layer_sizes_backup = if paged_attn_config.is_some() {
184        Some(layer_sizes_in_bytes.clone())
185    } else {
186        None
187    };
188
189    let mut remaining = total_model_size_in_bytes;
190    let max_seq_len = match params {
191        AutoDeviceMapParams::Text { max_seq_len, .. }
192        | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
193    };
194    let max_batch_size = match params {
195        AutoDeviceMapParams::Text { max_batch_size, .. }
196        | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
197    };
198
199    let model_cfg = loader.model_config(config)?;
200    let kv_cache_elems = match paged_attn_config {
201        Some(cfg) => {
202            let cache = calculate_cache_config(
203                cfg.mem_gpu,
204                Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
205                dtype,
206                paged_attn_config
207                    .map(|cfg| cfg.cache_type)
208                    .unwrap_or_default(),
209                &*model_cfg,
210                &devices[0],
211                &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
212                true,
213            )?;
214            let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
215            let key_sz =
216                cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
217            let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
218            let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
219            key_sz + val_sz
220        }
221        None => {
222            let key_shape = [
223                max_batch_size,
224                model_cfg.num_kv_heads(),
225                max_seq_len,
226                model_cfg.k_head_dim(),
227            ];
228            let val_shape = [
229                max_batch_size,
230                model_cfg.num_kv_heads(),
231                max_seq_len,
232                model_cfg.v_head_dim(),
233            ];
234            key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
235        }
236    };
237    let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
238
239    // prepare available memory per device, CPU fallback last
240    let mut avail = Vec::new();
241    for dev in [devices, &[Device::Cpu]].concat() {
242        let a = MemoryUsage.get_memory_available(&dev)?;
243        avail.push((a, dev));
244    }
245    avail.reverse();
246    layer_sizes_in_bytes.reverse();
247
248    let mut mappings = Vec::new();
249    info!("Using automatic device mapping parameters: {params}.");
250    if let Some(subs) = loader.non_mapped_sub_models() {
251        let (_, last) = avail.last().unwrap();
252        info!(
253            "The following sub-models will not be device mapped and will be loaded on {}: {}",
254            last.device_pretty_repr(),
255            subs.iter().map(|x| x.to_string()).join(", ")
256        );
257    }
258
259    let mut ordinal = 0;
260    let mut layer = 0;
261    let avail_copy = avail.clone();
262    let mut includes_cpu = false;
263    while remaining > 0 && !avail.is_empty() {
264        let (avail_bytes, dev) = avail
265            .pop()
266            .context("No more devices to map to. The model does not fit on this system.")?;
267
268        // For CPU: effectively unlimited capacity since it can use swap memory
269        // For GPU/accelerators: keep a small dynamic safety reserve to avoid OOMs
270        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
271        let cap = if dev.is_cpu() {
272            // Allow unlimited capacity for CPU - swap will handle it
273            usize::MAX
274        } else {
275            let reserve_fraction = (avail_bytes as f64 * GPU_RESERVE_FRACTION) as usize;
276            let reserve = reserve_fraction.max(GPU_MIN_RESERVE_BYTES).min(avail_bytes);
277            avail_bytes.saturating_sub(reserve)
278        };
279
280        // Algorithm is to check the following:
281        // 1) (no mapping) if *everything* fits on the first dev (non mapped and mapped)
282        // 2) if the mapped activations plus remaining fits on the nth device
283        // 3) common case, iteratively find the optimal amount of layers to put on the nth device
284        //   - if this is the first dev: must hold the non-mapped act and non-mapped model
285        //   - otherwise, must hold the mapped act
286        let required_whole_capacity = if ordinal == 0 {
287            remaining + non_mapped_max.max(mapped_max) + kv_cache_bytes * (num_layers - layer)
288        } else {
289            remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
290        };
291
292        let layers_on_dev = if cap >= required_whole_capacity {
293            remaining = 0;
294            num_layers - layer
295        } else {
296            let mut used = mapped_max;
297            let mut used_weight_bytes = 0;
298            let mut count = 0;
299            if ordinal == 0 {
300                used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
301                used_weight_bytes += non_mapped_size_in_bytes;
302            }
303            while let Some(&sz) = layer_sizes_in_bytes.last() {
304                let delta = sz + kv_cache_bytes;
305                if used + delta > cap {
306                    break;
307                }
308                layer_sizes_in_bytes.pop();
309                used += delta;
310                used_weight_bytes += sz;
311                count += 1;
312            }
313            if count > 0 {
314                remaining = remaining.saturating_sub(used_weight_bytes);
315            } else {
316                warn!(
317                    "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
318                    dev.device_pretty_repr(),
319                );
320                ordinal += 1;
321                continue;
322            }
323            count
324        };
325        if !dev.is_cpu() {
326            mappings.push(DeviceLayerMapMetadata {
327                ordinal,
328                layers: layers_on_dev,
329            });
330            ordinal += 1;
331        } else {
332            includes_cpu = true;
333        }
334        layer += layers_on_dev;
335    }
336    if remaining > 0 {
337        let over = b_to_mb!(remaining);
338        anyhow::bail!(
339            "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
340            avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
341            over
342        );
343    }
344    if paged_attn_config.is_some_and(|_| includes_cpu) {
345        let original_layers = layer_sizes_backup
346            .take()
347            .expect("layer sizes backup missing for paged attention fallback");
348        // The original vector was in forward order, but `get_device_layers` handles
349        // reversing internally, so we can pass it along unchanged.
350        return get_device_layers(
351            loader,
352            config,
353            num_layers,
354            original_layers,
355            non_mapped_size_in_bytes,
356            total_model_size_in_bytes,
357            devices,
358            dtype,
359            params,
360            None,
361        );
362    }
363    Ok(DeviceMapMetadata::from_num_device_layers(mappings))
364}