use std::{fmt::Debug, sync::Arc};

use crate::{
    pipeline::AutoDeviceMapParams,
    utils::{debug::DeviceRepr, log::once_log_info},
    MemoryUsage, Topology, TryIntoDType,
};
use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use tracing::info;

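/// Assigns a number of repeating layers to the device with the given ordinal.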
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceLayerMapMetadata {
    pub ordinal: usize,
    pub layers: usize,
}

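/// How the model's repeating layers should be mapped onto devices.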
#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
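    /// An explicit, user-provided mapping of layers to devices.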
    Map(DeviceMapMetadata),
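    /// Automatically determine a mapping from the given parameters.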
    Auto(AutoDeviceMapParams),
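    /// Like [`Self::Nccl`], but without a pre-constructed communicator.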
    DummyNccl { nm_device: Device },
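    /// NCCL-parallelized loading on `nm_device` using the given communicator.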
    Nccl {
        nm_device: Device,
        comm: Arc<mistralrs_quant::Comm>,
    },
}

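/// An explicit device mapping: per-device layer counts plus an optional number of host (CPU) layers.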
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceMapMetadata {
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    host_layers: Option<usize>,
}

impl DeviceMapMetadata {
    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
        Self {
            device_layers: Some(device_layers),
            host_layers: None,
        }
    }
    pub fn dummy() -> Self {
        Self {
            device_layers: None,
            host_layers: None,
        }
    }
}

impl DeviceMapSetting {
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }
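    /// Build the concrete [`DeviceMapper`] for a model with `model_layers` repeating
    /// layers, optionally overridden by a [`Topology`].
    ///
    /// A minimal usage sketch (the `device` handle and the layer count of 32 are
    /// assumptions of this example, not fixed by the API):
    ///
    /// ```ignore
    /// let setting = DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
    ///     DeviceLayerMapMetadata { ordinal: 0, layers: 32 },
    /// ]));
    /// let mapper = setting.into_mapper(32, &device, None)?;
    /// ```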
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using an NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using an NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                if let Some(topology) = topology {
                    if topology.0.iter().all(|x| x.is_none()) {
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        let layers = topology
                            .0
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));

                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_cuda_with_stream(*ordinal)?
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

                combined.extend(vec![Device::Cpu; n_host_layers]);

                assert_eq!(combined.len(), model_layers);

                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i;
                            current_dev = variant;
                        }
                    }

                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }

                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on an Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}

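/// Maps a model's repeating layers, and the tensors and var builders that belong to them, onto their assigned devices.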
pub trait DeviceMapper: Debug {
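    /// Move `input` onto the device assigned to `layer`.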
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

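    /// Set the var builder's device to the one assigned to `layer`, taking ISQ loading into account.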
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
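    /// The device assigned to `layer`; when loading ISQ, the non-mapped device is returned instead.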
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
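    /// All distinct devices used by this mapping.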
    fn get_unique_devices(&self) -> Vec<Device>;
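    /// Move `x` to the non-mapped device (or to the CPU when loading ISQ).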
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
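    /// Set the var builder's device to the non-mapped device, taking ISQ loading into account.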
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
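    /// Number of repeating layers covered by this mapping.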
    fn num_device_mapping_layers(&self) -> usize;
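    /// The communication handle ([`mistralrs_quant::Comm`]) to use for `layer_idx`.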
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

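    /// Resolve `dtype` into a concrete [`DType`] against the devices of this mapping.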
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}

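/// Maps each repeating layer to an explicitly assigned device.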
#[derive(Debug)]
pub struct LayerDeviceMapper {
    mappings: Vec<Device>,
    nm_device: Device,
}

impl DeviceMapper for LayerDeviceMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer])
    }
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
            if !acc.iter().any(|d| d.same_device(device)) {
                acc.push(device.clone());
            }
            acc
        })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

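/// A no-op mapper that keeps every layer on a single device.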
#[derive(Debug)]
pub struct DummyDeviceMapper {
    nm_device: Device,
}

impl DeviceMapper for DummyDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        1
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

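/// Mapper used for NCCL tensor parallelism: every layer stays on `nm_device`, sharing the communicator when one is available.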
#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    model_layers: usize,
    comm: Option<Arc<mistralrs_quant::Comm>>,
}

impl DeviceMapper for NcclDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.model_layers
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        if let Some(comm) = &self.comm {
            Ok(comm.clone())
        } else {
            let id = mistralrs_quant::Id::new();
            Ok(Arc::new(mistralrs_quant::Comm::from_device(
                id,
                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
                0,
                1,
            )?))
        }
    }
}

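/// Mapper for NCCL pipeline parallelism: each layer carries its own device and communication handle.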
#[derive(Debug)]
pub struct NcclPipelineParallelMapper {
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    nm_device: Device,
}

impl DeviceMapper for NcclPipelineParallelMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer].1)
    }
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].1.clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer).map(|(_, x)| x)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings
            .iter()
            .fold(Vec::new(), |mut acc, (_, device)| {
                if !acc.iter().any(|d| d.same_device(device)) {
                    acc.push(device.clone());
                }
                acc
            })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        Ok(self.mappings[layer_idx].0.clone())
    }
}

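/// Enumerate every device on the same backend as `base`: the CPU, or all available CUDA or Metal ordinals.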
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            let total_ords = metal::Device::all().len();
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}