mistralrs_core/pipeline/loaders/auto_device_map.rs

use std::fmt::{self, Display};

use crate::paged_attention::{
    calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
};
use crate::utils::debug::DeviceRepr;
use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
use anyhow::{Context, Result};
use candle_core::{DType, Device};
use itertools::Itertools;
use tracing::{info, warn};

use super::DeviceMappedModelLoader;

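// Headroom reserved on each GPU so the mapper never allocates right up to the
// reported free memory: the larger of a 2% fraction and 512 MiB.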
const GPU_RESERVE_FRACTION: f64 = 0.02;
const GPU_MIN_RESERVE_BYTES: usize = 512 * 1024 * 1024;

#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    Vision,
    Audio,
}

impl Display for NonMappedSubModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NonMappedSubModel::Vision => write!(f, "vision"),
            NonMappedSubModel::Audio => write!(f, "audio"),
        }
    }
}

#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    Text {
        max_seq_len: usize,
        max_batch_size: usize,
    },
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        max_image_shape: (usize, usize),
        max_num_images: usize,
    },
}

impl AutoDeviceMapParams {
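    /// Promote `Text` params to `Vision` params so a vision model can be
    /// auto-mapped from text-only parameters, filling the image bounds with
    /// the defaults below; `Vision` params are returned unchanged. For
    /// example, `AutoDeviceMapParams::default_text().maybe_promote_to_vision()`
    /// yields a `Vision` variant with a (1024, 1024) max image shape and one image.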
    pub fn maybe_promote_to_vision(&self) -> Self {
        match *self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape: (
                    Self::DEFAULT_MAX_IMAGE_LENGTH,
                    Self::DEFAULT_MAX_IMAGE_LENGTH,
                ),
                max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            },
            Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            } => Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            },
        }
    }

    pub fn max_seq_len(&self) -> usize {
        match self {
            Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
        }
    }
}

impl Display for AutoDeviceMapParams {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => write!(
                f,
                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
            ),
            Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            } => write!(
                f,
                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
            ),
        }
    }
}

impl AutoDeviceMapParams {
    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

    pub fn default_text() -> Self {
        Self::Text {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
        }
    }

    pub fn default_vision() -> Self {
        Self::Vision {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            max_image_shape: (
                Self::DEFAULT_MAX_IMAGE_LENGTH,
                Self::DEFAULT_MAX_IMAGE_LENGTH,
            ),
        }
    }
}

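// Shape of one paged-attention key cache block: (num_kv_heads, head_dim / x,
// block_size, x), where x = 16 / element_size so that the innermost x
// elements span exactly 16 bytes for vectorized access.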
fn calculate_key_block_shape(
    model_config: &dyn ModelConfigLike,
    dtype: DType,
    block_size: usize,
) -> (usize, usize, usize, usize) {
    let element_size = dtype.size_in_bytes();
    let x = 16 / element_size;
    (
        model_config.num_kv_heads(),
        model_config.k_head_dim() / x,
        block_size,
        x,
    )
}

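// Shape of one paged-attention value cache block: (num_kv_heads, v_head_dim, block_size).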
fn calculate_value_block_shape(
    model_config: &dyn ModelConfigLike,
    block_size: usize,
) -> (usize, usize, usize) {
    (
        model_config.num_kv_heads(),
        model_config.v_head_dim(),
        block_size,
    )
}

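// Convert a byte count to mebibytes for log output.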
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}

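/// Greedily assign `num_layers` decoder layers to `devices` in order, giving
/// each device as many layers as fit after reserving room for peak
/// activations and the per-layer KV cache, and spilling any remainder onto
/// the CPU. Returns the resulting per-device layer counts.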
#[allow(clippy::too_many_arguments)]
pub fn get_device_layers(
    loader: &dyn DeviceMappedModelLoader,
    config: &str,
    num_layers: usize,
    mut layer_sizes_in_bytes: Vec<usize>,
    non_mapped_size_in_bytes: usize,
    total_model_size_in_bytes: usize,
    devices: &[Device],
    dtype: DType,
    params: &AutoDeviceMapParams,
    paged_attn_config: Option<&PagedAttentionConfig>,
) -> Result<DeviceMapMetadata> {
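    // Peak activation sizes in bytes for the device-mapped (per-layer) portion
    // and the non-mapped portion of the model.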
    let mapped_max = loader.mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
    let non_mapped_max =
        loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();

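    // Keep an untouched copy of the per-layer sizes: if the mapping ends up
    // spilling onto the CPU, we restart below with paged attention disabled.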
    let mut layer_sizes_backup = if paged_attn_config.is_some() {
        Some(layer_sizes_in_bytes.clone())
    } else {
        None
    };

    let mut remaining = total_model_size_in_bytes;
    let max_seq_len = match params {
        AutoDeviceMapParams::Text { max_seq_len, .. }
        | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
    };
    let max_batch_size = match params {
        AutoDeviceMapParams::Text { max_batch_size, .. }
        | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
    };

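    // Per-layer KV cache size in elements: with paged attention this comes
    // from the block allocation computed by `calculate_cache_config`;
    // otherwise it is a dense [batch, kv_heads, seq_len, head_dim] cache for
    // keys plus values.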
    let model_cfg = loader.model_config(config)?;
    let kv_cache_elems = match paged_attn_config {
        Some(cfg) => {
            let cache = calculate_cache_config(
                cfg.mem_gpu,
                Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
                dtype,
                cfg.cache_type,
                &*model_cfg,
                &devices[0],
                &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
                true,
            )?;
            let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
            let key_sz =
                cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
            let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
            let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
            key_sz + val_sz
        }
        None => {
            let key_shape = [
                max_batch_size,
                model_cfg.num_kv_heads(),
                max_seq_len,
                model_cfg.k_head_dim(),
            ];
            let val_shape = [
                max_batch_size,
                model_cfg.num_kv_heads(),
                max_seq_len,
                model_cfg.v_head_dim(),
            ];
            key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
        }
    };
    let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();

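    // Query free memory for every device, appending the CPU as the final
    // fallback target. Both `avail` and the layer sizes are reversed so that
    // `pop()` yields the first device and the first layer, respectively.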
    let mut avail = Vec::new();
    for dev in [devices, &[Device::Cpu]].concat() {
        let a = MemoryUsage.get_memory_available(&dev)?;
        avail.push((a, dev));
    }
    avail.reverse();
    layer_sizes_in_bytes.reverse();

    let mut mappings = Vec::new();
    info!("Using automatic device mapping parameters: {params}.");
    if let Some(subs) = loader.non_mapped_sub_models() {
        let (_, last) = avail.last().unwrap();
        info!(
            "The following sub-models will not be device mapped and will be loaded on {}: {}",
            last.device_pretty_repr(),
            subs.iter().map(|x| x.to_string()).join(", ")
        );
    }

    let mut ordinal = 0;
    let mut layer = 0;
    let avail_copy = avail.clone();
    let mut includes_cpu = false;
    while remaining > 0 && !avail.is_empty() {
        let (avail_bytes, dev) = avail
            .pop()
            .context("No more devices to map to. The model does not fit on this system.")?;

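        // Usable capacity: the CPU is treated as unbounded, while each GPU
        // keeps back a reserve (the larger of GPU_RESERVE_FRACTION and
        // GPU_MIN_RESERVE_BYTES, capped at the available bytes) as headroom
        // against allocator overhead and fragmentation.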
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        let cap = if dev.is_cpu() {
            usize::MAX
        } else {
            let reserve_fraction = (avail_bytes as f64 * GPU_RESERVE_FRACTION) as usize;
            let reserve = reserve_fraction.max(GPU_MIN_RESERVE_BYTES).min(avail_bytes);
            avail_bytes.saturating_sub(reserve)
        };

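        // Capacity needed to place everything that remains on this single
        // device: remaining weights, peak activations, and the KV cache for
        // every unplaced layer. The first device also hosts the non-mapped
        // components, so it must additionally cover their activations.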
        let required_whole_capacity = if ordinal == 0 {
            remaining + non_mapped_max.max(mapped_max) + kv_cache_bytes * (num_layers - layer)
        } else {
            remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
        };

        let layers_on_dev = if cap >= required_whole_capacity {
            remaining = 0;
            num_layers - layer
        } else {
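            // Otherwise, greedily pack layers onto this device: each layer
            // costs its weights plus one layer's KV cache, on top of the
            // standing activation budget (and, on the first device, the
            // non-mapped weights).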
            let mut used = mapped_max;
            let mut used_weight_bytes = 0;
            let mut count = 0;
            if ordinal == 0 {
                used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
                used_weight_bytes += non_mapped_size_in_bytes;
            }
            while let Some(&sz) = layer_sizes_in_bytes.last() {
                let delta = sz + kv_cache_bytes;
                if used + delta > cap {
                    break;
                }
                layer_sizes_in_bytes.pop();
                used += delta;
                used_weight_bytes += sz;
                count += 1;
            }
            if count > 0 {
                remaining = remaining.saturating_sub(used_weight_bytes);
            } else {
                warn!(
                    "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
                    dev.device_pretty_repr(),
                );
                ordinal += 1;
                continue;
            }
            count
        };
        if !dev.is_cpu() {
            mappings.push(DeviceLayerMapMetadata {
                ordinal,
                layers: layers_on_dev,
            });
            ordinal += 1;
        } else {
            includes_cpu = true;
        }
        layer += layers_on_dev;
    }
    if remaining > 0 {
        let over = b_to_mb!(remaining);
        anyhow::bail!(
            "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
            avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
            over
        );
    }

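    // Paged attention cannot back layers that were offloaded to the CPU, so
    // if the mapping spilled over, redo it from the preserved layer sizes
    // with paged attention disabled.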
    if paged_attn_config.is_some_and(|_| includes_cpu) {
        let original_layers = layer_sizes_backup
            .take()
            .expect("layer sizes backup missing for paged attention fallback");
        return get_device_layers(
            loader,
            config,
            num_layers,
            original_layers,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            None,
        );
    }
    Ok(DeviceMapMetadata::from_num_device_layers(mappings))
}