mistralrs_core/pipeline/loaders/
auto_device_map.rs

1use std::fmt::{self, Display};
2
3use crate::paged_attention::{
4    calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
5};
6use crate::utils::debug::DeviceRepr;
7use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
8use anyhow::{Context, Result};
9use candle_core::{DType, Device};
10use itertools::Itertools;
11use tracing::{info, warn};
12
13use super::DeviceMappedModelLoader;
14
/// Sub-models of a multimodal model (e.g. encoders/towers) that are excluded
/// from layer-wise device mapping and are instead loaded whole onto one device.
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    /// Vision encoder/tower.
    Vision,
    /// Audio encoder/tower.
    Audio,
}
20
21impl Display for NonMappedSubModel {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        match self {
24            NonMappedSubModel::Vision => write!(f, "vision"),
25            NonMappedSubModel::Audio => write!(f, "audio"),
26        }
27    }
28}
29
/// Worst-case workload bounds used when automatically assigning model layers
/// to devices: they size the activation and KV-cache memory estimates.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    /// Bounds for a text-only model.
    Text {
        /// Maximum sequence length to plan for.
        max_seq_len: usize,
        /// Maximum concurrent batch size to plan for.
        max_batch_size: usize,
    },
    /// Bounds for a vision (multimodal) model.
    Vision {
        /// Maximum sequence length to plan for.
        max_seq_len: usize,
        /// Maximum concurrent batch size to plan for.
        max_batch_size: usize,
        // NOTE(review): presumably (height, width) in pixels — confirm against callers.
        max_image_shape: (usize, usize),
        /// Maximum number of images per request.
        max_num_images: usize,
    },
}
43
44impl AutoDeviceMapParams {
45    pub fn maybe_promote_to_vision(&self) -> Self {
46        match *self {
47            Self::Text {
48                max_seq_len,
49                max_batch_size,
50            } => Self::Vision {
51                max_seq_len,
52                max_batch_size,
53                max_image_shape: (
54                    Self::DEFAULT_MAX_IMAGE_LENGTH,
55                    Self::DEFAULT_MAX_IMAGE_LENGTH,
56                ),
57                max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
58            },
59            Self::Vision {
60                max_seq_len,
61                max_batch_size,
62                max_image_shape,
63                max_num_images,
64            } => Self::Vision {
65                max_seq_len,
66                max_batch_size,
67                max_image_shape,
68                max_num_images,
69            },
70        }
71    }
72
73    pub fn max_seq_len(&self) -> usize {
74        match self {
75            Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
76        }
77    }
78}
79
80impl Display for AutoDeviceMapParams {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        match self {
83            Self::Text {
84                max_seq_len,
85                max_batch_size,
86            } => write!(
87                f,
88                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
89            ),
90            Self::Vision {
91                max_seq_len,
92                max_batch_size,
93                max_image_shape,
94                max_num_images,
95            } => write!(
96                f,
97                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
98            ),
99        }
100    }
101}
102
103impl AutoDeviceMapParams {
104    // Default max sequence length for memory estimation when not specified
105    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
106    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
107    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
108    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;
109
110    pub fn default_text() -> Self {
111        Self::Text {
112            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
113            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
114        }
115    }
116
117    pub fn default_vision() -> Self {
118        Self::Vision {
119            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
120            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
121            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
122            max_image_shape: (
123                Self::DEFAULT_MAX_IMAGE_LENGTH,
124                Self::DEFAULT_MAX_IMAGE_LENGTH,
125            ),
126        }
127    }
128}
129
130fn calculate_key_block_shape(
131    model_config: &dyn ModelConfigLike,
132    dtype: DType,
133    block_size: usize,
134) -> (usize, usize, usize, usize) {
135    let element_size = dtype.size_in_bytes();
136    let x = 16 / element_size;
137    (
138        model_config.num_kv_heads(),
139        model_config.k_head_dim() / x,
140        block_size,
141        x,
142    )
143}
144
145fn calculate_value_block_shape(
146    model_config: &dyn ModelConfigLike,
147    block_size: usize,
148) -> (usize, usize, usize) {
149    (
150        model_config.num_kv_heads(),
151        model_config.v_head_dim(),
152        block_size,
153    )
154}
155
/// Convert a byte count to mebibytes (integer division, truncating).
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
161
162#[allow(clippy::too_many_arguments)]
163/// Core logic for automatic device mapping
164pub fn get_device_layers(
165    loader: &dyn DeviceMappedModelLoader,
166    config: &str,
167    num_layers: usize,
168    mut layer_sizes_in_bytes: Vec<usize>,
169    non_mapped_size_in_bytes: usize,
170    total_model_size_in_bytes: usize,
171    devices: &[Device],
172    dtype: DType,
173    params: &AutoDeviceMapParams,
174    paged_attn_config: Option<&PagedAttentionConfig>,
175) -> Result<DeviceMapMetadata> {
176    let mapped_max = loader.mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
177    let non_mapped_max =
178        loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
179
180    let mut remaining = total_model_size_in_bytes;
181    let max_seq_len = match params {
182        AutoDeviceMapParams::Text { max_seq_len, .. }
183        | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
184    };
185    let max_batch_size = match params {
186        AutoDeviceMapParams::Text { max_batch_size, .. }
187        | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
188    };
189
190    let model_cfg = loader.model_config(config)?;
191    let kv_cache_elems = match paged_attn_config {
192        Some(cfg) => {
193            let cache = calculate_cache_config(
194                cfg.mem_gpu,
195                Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
196                dtype,
197                paged_attn_config
198                    .map(|cfg| cfg.cache_type)
199                    .unwrap_or_default(),
200                &*model_cfg,
201                &devices[0],
202                &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
203                true,
204            )?;
205            let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
206            let key_sz =
207                cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
208            let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
209            let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
210            key_sz + val_sz
211        }
212        None => {
213            let key_shape = [
214                max_batch_size,
215                model_cfg.num_kv_heads(),
216                max_seq_len,
217                model_cfg.k_head_dim(),
218            ];
219            let val_shape = [
220                max_batch_size,
221                model_cfg.num_kv_heads(),
222                max_seq_len,
223                model_cfg.v_head_dim(),
224            ];
225            key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
226        }
227    };
228    let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
229
230    // prepare available memory per device, CPU fallback last
231    let mut avail = Vec::new();
232    for dev in [devices, &[Device::Cpu]].concat() {
233        let a = MemoryUsage.get_memory_available(&dev)?;
234        avail.push((a, dev));
235    }
236    avail.reverse();
237    layer_sizes_in_bytes.reverse();
238
239    let mut mappings = Vec::new();
240    info!("Using automatic device mapping parameters: {params}.");
241    if let Some(subs) = loader.non_mapped_sub_models() {
242        let (_, last) = avail.last().unwrap();
243        info!(
244            "The following sub-models will not be device mapped and will be loaded on {}: {}",
245            last.device_pretty_repr(),
246            subs.iter().map(|x| x.to_string()).join(", ")
247        );
248    }
249
250    let mut ordinal = 0;
251    let mut layer = 0;
252    let avail_copy = avail.clone();
253    let mut includes_cpu = false;
254    while remaining > 0 && !avail.is_empty() {
255        let (cap, dev) = avail
256            .pop()
257            .context("No more devices to map to. The model does not fit on this system.")?;
258
259        // For CPU: effectively unlimited capacity since it can use swap memory
260        // For GPU/accelerators: use 90% of available memory as maximum
261        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
262        let cap = if dev.is_cpu() {
263            // Allow unlimited capacity for CPU - swap will handle it
264            usize::MAX
265        } else {
266            (cap as f64 * 0.90) as usize
267        };
268
269        // Algorithm is to check the following:
270        // 1) (no mapping) if *everything* fits on the first dev (non mapped and mapped)
271        // 2) if the mapped activations plus remaining fits on the nth device
272        // 3) common case, iteratively find the optimal amount of layers to put on the nth device
273        //   - if this is the first dev: must hold the non-mapped act and non-mapped model
274        //   - otherwise, must hold the mapped act
275        let required_whole_capacity = if ordinal == 0 {
276            remaining
277                + non_mapped_max.max(mapped_max)
278                + non_mapped_size_in_bytes
279                + kv_cache_bytes * (num_layers - layer)
280        } else {
281            remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
282        };
283
284        let layers_on_dev = if cap >= required_whole_capacity {
285            remaining = 0;
286            num_layers - layer
287        } else {
288            let mut used = mapped_max;
289            let mut used_no_act = 0;
290            let mut count = 0;
291            if ordinal == 0 {
292                used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
293                used_no_act += non_mapped_size_in_bytes;
294            }
295            while let Some(&sz) = layer_sizes_in_bytes.last() {
296                let delta = sz + kv_cache_bytes;
297                if used + delta > cap {
298                    break;
299                }
300                layer_sizes_in_bytes.pop();
301                used += delta;
302                used_no_act += delta;
303                count += 1;
304            }
305            if count > 0 {
306                remaining = remaining.saturating_sub(used_no_act);
307            } else {
308                warn!(
309                    "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
310                    dev.device_pretty_repr(),
311                );
312                ordinal += 1;
313                continue;
314            }
315            count
316        };
317        if !dev.is_cpu() {
318            mappings.push(DeviceLayerMapMetadata {
319                ordinal,
320                layers: layers_on_dev,
321            });
322            ordinal += 1;
323        } else {
324            includes_cpu = true;
325        }
326        layer += layers_on_dev;
327    }
328    if remaining > 0 {
329        let over = b_to_mb!(remaining);
330        anyhow::bail!(
331            "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
332            avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
333            over
334        );
335    }
336    if paged_attn_config.is_some_and(|_| includes_cpu) {
337        return get_device_layers(
338            loader,
339            config,
340            num_layers,
341            layer_sizes_in_bytes,
342            non_mapped_size_in_bytes,
343            total_model_size_in_bytes,
344            devices,
345            dtype,
346            params,
347            None,
348        );
349    }
350    Ok(DeviceMapMetadata::from_num_device_layers(mappings))
351}