use std::fmt::{self, Display};
2
3use crate::paged_attention::{
4 calculate_cache_config, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
5};
6use crate::utils::debug::DeviceRepr;
7use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
8use anyhow::{Context, Result};
9use candle_core::{DType, Device};
10use itertools::Itertools;
11use tracing::{info, warn};
12
13use super::DeviceMappedModelLoader;
14
/// Sub-models that are never split across the layer device map and are instead
/// loaded whole onto a single device (the last/fallback device — see
/// `get_device_layers`, which logs this).
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    /// A vision tower/encoder attached to a multimodal model.
    Vision,
}
19
20impl Display for NonMappedSubModel {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 match self {
23 NonMappedSubModel::Vision => write!(f, "vision"),
24 }
25 }
26}
27
/// Workload parameters that bound the activation- and KV-cache-memory
/// estimates used by automatic device mapping.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    /// Text-only workload.
    Text {
        // Maximum sequence length to budget KV cache and activations for.
        max_seq_len: usize,
        // Maximum concurrent batch size to budget for.
        max_batch_size: usize,
    },
    /// Multimodal (text + image) workload.
    Vision {
        // Maximum sequence length to budget KV cache and activations for.
        max_seq_len: usize,
        // Maximum concurrent batch size to budget for.
        max_batch_size: usize,
        // (height, width) upper bound of input images — TODO confirm order
        // against the vision loaders that consume this.
        max_image_shape: (usize, usize),
        // Maximum number of images per request to budget for.
        max_num_images: usize,
    },
}
41
42impl AutoDeviceMapParams {
43 pub fn maybe_promote_to_vision(&self) -> Self {
44 match *self {
45 Self::Text {
46 max_seq_len,
47 max_batch_size,
48 } => Self::Vision {
49 max_seq_len,
50 max_batch_size,
51 max_image_shape: (
52 Self::DEFAULT_MAX_IMAGE_LENGTH,
53 Self::DEFAULT_MAX_IMAGE_LENGTH,
54 ),
55 max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
56 },
57 Self::Vision {
58 max_seq_len,
59 max_batch_size,
60 max_image_shape,
61 max_num_images,
62 } => Self::Vision {
63 max_seq_len,
64 max_batch_size,
65 max_image_shape,
66 max_num_images,
67 },
68 }
69 }
70
71 pub fn max_seq_len(&self) -> usize {
72 match self {
73 Self::Text { max_seq_len, .. } | Self::Vision { max_seq_len, .. } => *max_seq_len,
74 }
75 }
76}
77
78impl Display for AutoDeviceMapParams {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 Self::Text {
82 max_seq_len,
83 max_batch_size,
84 } => write!(
85 f,
86 "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
87 ),
88 Self::Vision {
89 max_seq_len,
90 max_batch_size,
91 max_image_shape,
92 max_num_images,
93 } => write!(
94 f,
95 "vision[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
96 ),
97 }
98 }
99}
100
impl AutoDeviceMapParams {
    /// Default sequence-length budget: 4k tokens.
    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
    /// Default batch-size budget: a single sequence.
    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
    /// Default number of images budgeted per request.
    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
    /// Default per-side image length (used for both height and width).
    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

    /// Text-only params with the default sequence-length and batch budgets.
    pub fn default_text() -> Self {
        Self::Text {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
        }
    }

    /// Vision params with the default text budgets plus one image of the
    /// default (square) shape.
    pub fn default_vision() -> Self {
        Self::Vision {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            max_image_shape: (
                Self::DEFAULT_MAX_IMAGE_LENGTH,
                Self::DEFAULT_MAX_IMAGE_LENGTH,
            ),
        }
    }
}
126
127fn calculate_key_block_shape(
128 model_config: &dyn ModelConfigLike,
129 dtype: DType,
130 block_size: usize,
131) -> (usize, usize, usize, usize) {
132 let element_size = dtype.size_in_bytes();
133 let x = 16 / element_size;
134 (
135 model_config.num_kv_heads(),
136 model_config.k_head_dim() / x,
137 block_size,
138 x,
139 )
140}
141
142fn calculate_value_block_shape(
143 model_config: &dyn ModelConfigLike,
144 block_size: usize,
145) -> (usize, usize, usize) {
146 (
147 model_config.num_kv_heads(),
148 model_config.v_head_dim(),
149 block_size,
150 )
151}
152
/// Convert a byte count to whole mebibytes (integer division); used only for
/// log/error formatting below.
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
158
159#[allow(clippy::too_many_arguments)]
160pub fn get_device_layers(
162 loader: &dyn DeviceMappedModelLoader,
163 config: &str,
164 num_layers: usize,
165 mut layer_sizes_in_bytes: Vec<usize>,
166 non_mapped_size_in_bytes: usize,
167 total_model_size_in_bytes: usize,
168 devices: &[Device],
169 dtype: DType,
170 params: &AutoDeviceMapParams,
171 prompt_chunksize: usize,
172 paged_attn_config: Option<&PagedAttentionConfig>,
173) -> Result<DeviceMapMetadata> {
174 let mapped_max =
175 loader.mapped_max_act_size_elems(config, params, prompt_chunksize)? * dtype.size_in_bytes();
176 let non_mapped_max =
177 loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
178
179 let mut remaining = total_model_size_in_bytes;
180 let max_seq_len = match params {
181 AutoDeviceMapParams::Text { max_seq_len, .. }
182 | AutoDeviceMapParams::Vision { max_seq_len, .. } => *max_seq_len,
183 };
184 let max_batch_size = match params {
185 AutoDeviceMapParams::Text { max_batch_size, .. }
186 | AutoDeviceMapParams::Vision { max_batch_size, .. } => *max_batch_size,
187 };
188
189 let model_cfg = loader.model_config(config)?;
190 let kv_cache_elems = match paged_attn_config {
191 Some(cfg) => {
192 let cache = calculate_cache_config(
193 cfg.mem_gpu,
194 cfg.mem_cpu,
195 Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
196 dtype,
197 &*model_cfg,
198 &devices[0],
199 &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
200 true,
201 )?;
202 let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
203 let key_sz =
204 cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
205 let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
206 let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
207 key_sz + val_sz
208 }
209 None => {
210 let key_shape = [
211 max_batch_size,
212 model_cfg.num_kv_heads(),
213 max_seq_len,
214 model_cfg.k_head_dim(),
215 ];
216 let val_shape = [
217 max_batch_size,
218 model_cfg.num_kv_heads(),
219 max_seq_len,
220 model_cfg.v_head_dim(),
221 ];
222 key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
223 }
224 };
225 let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
226
227 let mut avail = Vec::new();
229 for dev in [devices, &[Device::Cpu]].concat() {
230 let a = MemoryUsage.get_memory_available(&dev)?;
231 avail.push((a, dev));
232 }
233 avail.reverse();
234 layer_sizes_in_bytes.reverse();
235
236 let mut mappings = Vec::new();
237 info!("Using automatic device mapping parameters: {params}.");
238 if let Some(subs) = loader.non_mapped_sub_models() {
239 let (_, last) = avail.last().unwrap();
240 info!(
241 "The following sub-models will not be device mapped and will be loaded on {}: {}",
242 last.device_pretty_repr(),
243 subs.iter().map(|x| x.to_string()).join(", ")
244 );
245 }
246
247 let mut ordinal = 0;
248 let mut layer = 0;
249 let avail_copy = avail.clone();
250 let mut includes_cpu = false;
251 while remaining > 0 && !avail.is_empty() {
252 let (cap, dev) = avail
253 .pop()
254 .context("No more devices to map to. The model does not fit on this system.")?;
255
256 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
258 let cap = (cap as f64 * 0.90) as usize;
259
260 let required_whole_capacity = if ordinal == 0 {
267 remaining
268 + non_mapped_max.max(mapped_max)
269 + non_mapped_size_in_bytes
270 + kv_cache_bytes * (num_layers - layer)
271 } else {
272 remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
273 };
274
275 let layers_on_dev = if cap >= required_whole_capacity {
276 remaining = 0;
277 num_layers - layer
278 } else {
279 let mut used = mapped_max;
280 let mut used_no_act = 0;
281 let mut count = 0;
282 if ordinal == 0 {
283 used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
284 used_no_act += non_mapped_size_in_bytes;
285 }
286 while let Some(&sz) = layer_sizes_in_bytes.last() {
287 let delta = sz + kv_cache_bytes;
288 if used + delta > cap {
289 break;
290 }
291 layer_sizes_in_bytes.pop();
292 used += delta;
293 used_no_act += delta;
294 count += 1;
295 }
296 if count > 0 {
297 remaining = remaining.saturating_sub(used_no_act);
298 } else {
299 warn!(
300 "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
301 dev.device_pretty_repr(),
302 );
303 ordinal += 1;
304 continue;
305 }
306 count
307 };
308 if !dev.is_cpu() {
309 mappings.push(DeviceLayerMapMetadata {
310 ordinal,
311 layers: layers_on_dev,
312 });
313 ordinal += 1;
314 } else {
315 includes_cpu = true;
316 }
317 layer += layers_on_dev;
318 }
319 if remaining > 0 {
320 let over = b_to_mb!(remaining);
321 anyhow::bail!(
322 "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
323 avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
324 over
325 );
326 }
327 if paged_attn_config.is_some_and(|_| includes_cpu) {
328 return get_device_layers(
329 loader,
330 config,
331 num_layers,
332 layer_sizes_in_bytes,
333 non_mapped_size_in_bytes,
334 total_model_size_in_bytes,
335 devices,
336 dtype,
337 params,
338 prompt_chunksize,
339 None,
340 );
341 }
342 Ok(DeviceMapMetadata::from_num_device_layers(mappings))
343}