// mistralrs_core/pipeline/loaders/auto_device_map.rs
use std::fmt::{self, Display};
2
3use crate::paged_attention::{
4 calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
5};
6use crate::utils::debug::DeviceRepr;
7use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
8use anyhow::{Context, Result};
9use candle_core::{DType, Device};
10use itertools::Itertools;
11use tracing::{info, warn};
12
13use super::DeviceMappedModelLoader;
14
/// Sub-models (e.g. multimodal encoders) that are never split across devices
/// by automatic device mapping; they are loaded whole on a single device
/// (see `get_device_layers`, which logs them against the last device).
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    // Vision encoder component of a multimodal model.
    Vision,
    // Audio encoder component of a multimodal model.
    Audio,
}
20
21impl Display for NonMappedSubModel {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 match self {
24 NonMappedSubModel::Vision => write!(f, "vision"),
25 NonMappedSubModel::Audio => write!(f, "audio"),
26 }
27 }
28}
29
/// Worst-case workload parameters that drive automatic device mapping: the
/// mapping is sized so that a request at these limits (sequence length,
/// batch size, and — for vision models — image count/shape) fits in memory.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    // Parameters for text-only models.
    Text {
        max_seq_len: usize,
        max_batch_size: usize,
    },
    // Parameters for vision (multimodal) models; superset of `Text`.
    Vision {
        max_seq_len: usize,
        max_batch_size: usize,
        // Upper bound on a single image's dimensions — presumably
        // (height, width); TODO confirm against vision loaders.
        max_image_shape: (usize, usize),
        max_num_images: usize,
    },
}
43
44impl AutoDeviceMapParams {
45 pub fn maybe_promote_to_vision(&self) -> Self {
46 match *self {
47 Self::Text {
48 max_seq_len,
49 max_batch_size,
50 } => Self::Vision {
51 max_seq_len,
52 max_batch_size,
53 max_image_shape: (
54 Self::DEFAULT_MAX_IMAGE_LENGTH,
55 Self::DEFAULT_MAX_IMAGE_LENGTH,
56 ),
57 max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
58 },
59 Self::Vision {
60 max_seq_len,
61 max_batch_size,
62 max_image_shape,
63 max_num_images,
64 } => Self::Vision {
65 max_seq_len,
66 max_batch_size,
67 max_image_shape,
68 max_num_images,
69 },
70 }
71 }
72
73 pub fn max_seq_len(&self) -> usize {
74 match self {
75 Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
76 }
77 }
78}
79
impl Display for AutoDeviceMapParams {
    /// Formats the parameters as a bracketed summary, e.g.
    /// `text[max_seq_len: 4096, max_batch_size: 1]`; used verbatim in the
    /// info/warn/error messages of `get_device_layers`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => write!(
                f,
                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
            ),
            Self::Vision {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            } => write!(
                f,
                "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
            ),
        }
    }
}
102
103impl AutoDeviceMapParams {
104 pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
106 pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
107 pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
108 pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;
109
110 pub fn default_text() -> Self {
111 Self::Text {
112 max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
113 max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
114 }
115 }
116
117 pub fn default_vision() -> Self {
118 Self::Vision {
119 max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
120 max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
121 max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
122 max_image_shape: (
123 Self::DEFAULT_MAX_IMAGE_LENGTH,
124 Self::DEFAULT_MAX_IMAGE_LENGTH,
125 ),
126 }
127 }
128}
129
130fn calculate_key_block_shape(
131 model_config: &dyn ModelConfigLike,
132 dtype: DType,
133 block_size: usize,
134) -> (usize, usize, usize, usize) {
135 let element_size = dtype.size_in_bytes();
136 let x = 16 / element_size;
137 (
138 model_config.num_kv_heads(),
139 model_config.k_head_dim() / x,
140 block_size,
141 x,
142 )
143}
144
145fn calculate_value_block_shape(
146 model_config: &dyn ModelConfigLike,
147 block_size: usize,
148) -> (usize, usize, usize) {
149 (
150 model_config.num_kv_heads(),
151 model_config.v_head_dim(),
152 block_size,
153 )
154}
155
/// Converts a byte count to whole mebibytes (integer division, truncating).
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1 << 20)
    };
}
161
162#[allow(clippy::too_many_arguments)]
163pub fn get_device_layers(
165 loader: &dyn DeviceMappedModelLoader,
166 config: &str,
167 num_layers: usize,
168 mut layer_sizes_in_bytes: Vec<usize>,
169 non_mapped_size_in_bytes: usize,
170 total_model_size_in_bytes: usize,
171 devices: &[Device],
172 dtype: DType,
173 params: &AutoDeviceMapParams,
174 prompt_chunksize: usize,
175 paged_attn_config: Option<&PagedAttentionConfig>,
176) -> Result<DeviceMapMetadata> {
177 let mapped_max =
178 loader.mapped_max_act_size_elems(config, params, prompt_chunksize)? * dtype.size_in_bytes();
179 let non_mapped_max =
180 loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
181
182 let mut remaining = total_model_size_in_bytes;
183 let max_seq_len = match params {
184 AutoDeviceMapParams::Text { max_seq_len, .. }
185 | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
186 };
187 let max_batch_size = match params {
188 AutoDeviceMapParams::Text { max_batch_size, .. }
189 | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
190 };
191
192 let model_cfg = loader.model_config(config)?;
193 let kv_cache_elems = match paged_attn_config {
194 Some(cfg) => {
195 let cache = calculate_cache_config(
196 cfg.mem_gpu,
197 cfg.mem_cpu,
198 Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
199 dtype,
200 paged_attn_config
201 .map(|cfg| cfg.cache_type)
202 .unwrap_or_default(),
203 &*model_cfg,
204 &devices[0],
205 &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
206 true,
207 )?;
208 let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
209 let key_sz =
210 cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
211 let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
212 let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
213 key_sz + val_sz
214 }
215 None => {
216 let key_shape = [
217 max_batch_size,
218 model_cfg.num_kv_heads(),
219 max_seq_len,
220 model_cfg.k_head_dim(),
221 ];
222 let val_shape = [
223 max_batch_size,
224 model_cfg.num_kv_heads(),
225 max_seq_len,
226 model_cfg.v_head_dim(),
227 ];
228 key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
229 }
230 };
231 let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
232
233 let mut avail = Vec::new();
235 for dev in [devices, &[Device::Cpu]].concat() {
236 let a = MemoryUsage.get_memory_available(&dev)?;
237 avail.push((a, dev));
238 }
239 avail.reverse();
240 layer_sizes_in_bytes.reverse();
241
242 let mut mappings = Vec::new();
243 info!("Using automatic device mapping parameters: {params}.");
244 if let Some(subs) = loader.non_mapped_sub_models() {
245 let (_, last) = avail.last().unwrap();
246 info!(
247 "The following sub-models will not be device mapped and will be loaded on {}: {}",
248 last.device_pretty_repr(),
249 subs.iter().map(|x| x.to_string()).join(", ")
250 );
251 }
252
253 let mut ordinal = 0;
254 let mut layer = 0;
255 let avail_copy = avail.clone();
256 let mut includes_cpu = false;
257 while remaining > 0 && !avail.is_empty() {
258 let (cap, dev) = avail
259 .pop()
260 .context("No more devices to map to. The model does not fit on this system.")?;
261
262 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
264 let cap = (cap as f64 * 0.90) as usize;
265
266 let required_whole_capacity = if ordinal == 0 {
273 remaining
274 + non_mapped_max.max(mapped_max)
275 + non_mapped_size_in_bytes
276 + kv_cache_bytes * (num_layers - layer)
277 } else {
278 remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
279 };
280
281 let layers_on_dev = if cap >= required_whole_capacity {
282 remaining = 0;
283 num_layers - layer
284 } else {
285 let mut used = mapped_max;
286 let mut used_no_act = 0;
287 let mut count = 0;
288 if ordinal == 0 {
289 used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
290 used_no_act += non_mapped_size_in_bytes;
291 }
292 while let Some(&sz) = layer_sizes_in_bytes.last() {
293 let delta = sz + kv_cache_bytes;
294 if used + delta > cap {
295 break;
296 }
297 layer_sizes_in_bytes.pop();
298 used += delta;
299 used_no_act += delta;
300 count += 1;
301 }
302 if count > 0 {
303 remaining = remaining.saturating_sub(used_no_act);
304 } else {
305 warn!(
306 "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
307 dev.device_pretty_repr(),
308 );
309 ordinal += 1;
310 continue;
311 }
312 count
313 };
314 if !dev.is_cpu() {
315 mappings.push(DeviceLayerMapMetadata {
316 ordinal,
317 layers: layers_on_dev,
318 });
319 ordinal += 1;
320 } else {
321 includes_cpu = true;
322 }
323 layer += layers_on_dev;
324 }
325 if remaining > 0 {
326 let over = b_to_mb!(remaining);
327 anyhow::bail!(
328 "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
329 avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
330 over
331 );
332 }
333 if paged_attn_config.is_some_and(|_| includes_cpu) {
334 return get_device_layers(
335 loader,
336 config,
337 num_layers,
338 layer_sizes_in_bytes,
339 non_mapped_size_in_bytes,
340 total_model_size_in_bytes,
341 devices,
342 dtype,
343 params,
344 prompt_chunksize,
345 None,
346 );
347 }
348 Ok(DeviceMapMetadata::from_num_device_layers(mappings))
349}