1use std::{fmt::Debug, sync::Arc};
2
3use crate::{
4 pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
5};
6use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
7use mistralrs_quant::log::once_log_info;
8use mistralrs_quant::ShardedVarBuilder;
9use serde::Deserialize;
10use tracing::info;
11
/// Mapping of a contiguous run of repeating layers onto one device ordinal.
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceLayerMapMetadata {
    /// Device ordinal (e.g. CUDA or Metal GPU index) that hosts the layers.
    pub ordinal: usize,
    /// Number of repeating layers placed on this device.
    pub layers: usize,
}
17
/// How a model's repeating layers should be distributed across devices.
#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
    /// Explicit mapping of layer counts to device ordinals.
    Map(DeviceMapMetadata),
    /// Parameters for automatic device mapping; must be converted into `Map`
    /// before `into_mapper` is called (see the bail in `into_mapper`).
    Auto(AutoDeviceMapParams),
    /// NCCL-style pipeline without a live communicator (a world-size-1 comm is
    /// created on demand by the resulting mapper).
    DummyNccl { nm_device: Device },
    /// NCCL-parallelized pipeline sharing one communicator across all layers.
    Nccl {
        /// Device for non-mapped (non-repeating) weights.
        nm_device: Device,
        /// Communicator shared by every layer.
        comm: Arc<mistralrs_quant::Comm>,
    },
}
32
/// User-facing device-mapping description: how many repeating layers go on
/// each device, and optionally how many remain on the host.
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceMapMetadata {
    // Per-device layer counts; `None` means "no mapping" (everything stays on
    // the main device).
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    // Host-resident layer count; when `None`, inferred as the remainder
    // (model_layers - device layers) in `into_mapper`.
    host_layers: Option<usize>,
}
39
40impl DeviceMapMetadata {
41 pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
42 Self {
43 device_layers: Some(device_layers),
44 host_layers: None,
45 }
46 }
47 pub fn dummy() -> Self {
49 Self {
50 device_layers: None,
51 host_layers: None,
52 }
53 }
54}
55
impl DeviceMapSetting {
    /// A setting that maps nothing: all layers remain on the main device.
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }

    /// Resolve this setting into a concrete [`DeviceMapper`] for a model with
    /// `model_layers` repeating layers.
    ///
    /// * `device` — the "non-mapped" (main) device, also the fallback target.
    /// * `topology` — when present and non-empty, overrides a `Map` setting
    ///   with an explicit per-layer device assignment.
    ///
    /// Errors if the device/host layer counts do not sum to `model_layers`,
    /// or when called on an `Auto` setting (which must first be converted to
    /// `Map`).
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                // A topology, when given, takes precedence over the layer counts.
                if let Some(topology) = topology {
                    if topology.0.iter().all(|x| x.is_none()) {
                        // Topology present but assigns nothing: no mapping at all.
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        // Layers without an explicit device fall back to the
                        // main device.
                        let layers = topology
                            .0
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

                // Total GPU-resident layers, clamped to the model size; if no
                // per-device counts were supplied, nothing is mapped.
                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                // Remaining layers spill to the host unless explicitly set.
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));

                // Build the per-layer device list: device layers first, then
                // host (CPU) layers.
                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    // Single-entry fast path: all device layers go on `device`.
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    // Multi-device path: instantiate each ordinal (reusing
                    // `device` when the ordinal matches) and propagate the main
                    // device's RNG seed so every device samples identically.
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    // `unwrap()` is safe: the `else` above returned early when
                    // `device_layers` was `None`.
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_cuda(*ordinal)?
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            // `original_seed` is always `Some` here: same
                            // `!device.is_cpu()` guard as above.
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

                // Host layers are appended at the end of the stack.
                combined.extend(vec![Device::Cpu; n_host_layers]);

                assert_eq!(combined.len(), model_layers);

                // Log contiguous runs of layers sharing a device, with that
                // device's total memory in GB.
                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i; current_dev = variant;
                        }
                    }

                    // Final run (never logged inside the loop above).
                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }

                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}
235
/// Per-layer device placement used while loading and running a model.
///
/// Implementations decide which device each repeating layer lives on, where
/// the non-mapped ("nm") weights go, and which communicator a layer uses.
pub trait DeviceMapper: Debug {
    /// Move `input` to the device assigned to `layer`.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    /// Return a varbuilder targeting `layer`'s device. Behavior under
    /// `loading_isq` is implementation-defined (pass-through or CPU).
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    /// Device assigned to `layer`, if any.
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    /// All distinct devices used by this mapping.
    fn get_unique_devices(&self) -> Vec<Device>;
    /// Move `x` to the non-mapped device (CPU when `loading_isq`).
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    /// Varbuilder for non-mapped weights.
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    /// Number of repeating layers this mapper covers.
    fn num_device_mapping_layers(&self) -> usize;
    /// Communicator used for collective ops on `layer_idx`.
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

    /// Smallest dtype usable across all mapped devices.
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}
263
/// Mapper backed by an explicit per-layer device list.
#[derive(Debug)]
pub struct LayerDeviceMapper {
    // `mappings[i]` is the device for repeating layer `i`.
    mappings: Vec<Device>,
    // Device for non-mapped (non-repeating) weights.
    nm_device: Device,
}
270
271impl DeviceMapper for LayerDeviceMapper {
272 fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
273 input.to_device(&self.mappings[layer])
274 }
275 fn set_device<'a>(
276 &self,
277 layer: usize,
278 varbuilder: ShardedVarBuilder,
279 loading_isq: bool,
280 ) -> ShardedVarBuilder {
281 if loading_isq {
282 return varbuilder;
283 }
284 varbuilder.set_device(self.mappings[layer].clone())
285 }
286 fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
287 if loading_isq {
288 return Some(&self.nm_device);
289 }
290 self.mappings.get(layer)
291 }
292 fn get_unique_devices(&self) -> Vec<Device> {
293 self.mappings.iter().fold(Vec::new(), |mut acc, device| {
294 if !acc.iter().any(|d| d.same_device(device)) {
295 acc.push(device.clone());
296 }
297 acc
298 })
299 }
300 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
301 if loading_isq {
302 x.to_device(&Device::Cpu)
303 } else {
304 x.to_device(&self.nm_device)
305 }
306 }
307 fn set_nm_device<'a>(
308 &self,
309 varbuilder: ShardedVarBuilder,
310 loading_isq: bool,
311 ) -> ShardedVarBuilder {
312 if loading_isq {
313 varbuilder
314 } else {
315 varbuilder.set_device(self.nm_device.clone())
316 }
317 }
318 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
319 dtype
320 .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
321 .map_err(candle_core::Error::msg)
322 }
323 fn num_device_mapping_layers(&self) -> usize {
324 self.mappings.len()
325 }
326 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
327 let id = mistralrs_quant::Id::new();
328 Ok(Arc::new(mistralrs_quant::Comm::from_device(
329 id,
330 self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
331 0,
332 1,
333 )?))
334 }
335}
336
/// Trivial mapper: no per-layer placement; everything targets `nm_device`.
#[derive(Debug)]
pub struct DummyDeviceMapper {
    // The single device used for all weights and activations.
    pub(crate) nm_device: Device,
}
341
342impl DeviceMapper for DummyDeviceMapper {
343 fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
344 Ok(input)
345 }
346 fn set_device<'a>(
347 &self,
348 _: usize,
349 varbuilder: ShardedVarBuilder,
350 loading_isq: bool,
351 ) -> ShardedVarBuilder {
352 if loading_isq {
353 varbuilder.set_device(Device::Cpu)
354 } else {
355 varbuilder.set_device(self.nm_device.clone())
356 }
357 }
358 fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
359 Some(&self.nm_device)
360 }
361 fn get_unique_devices(&self) -> Vec<Device> {
362 vec![self.nm_device.clone()]
363 }
364 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
365 if loading_isq {
366 x.to_device(&Device::Cpu)
367 } else {
368 x.to_device(&self.nm_device)
369 }
370 }
371 fn set_nm_device<'a>(
372 &self,
373 varbuilder: ShardedVarBuilder,
374 loading_isq: bool,
375 ) -> ShardedVarBuilder {
376 if loading_isq {
377 varbuilder.set_device(Device::Cpu)
378 } else {
379 varbuilder.set_device(self.nm_device.clone())
380 }
381 }
382 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
383 dtype
384 .try_into_dtype(&[&self.nm_device])
385 .map_err(candle_core::Error::msg)
386 }
387 fn num_device_mapping_layers(&self) -> usize {
388 1
390 }
391 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
392 let id = mistralrs_quant::Id::new();
393 Ok(Arc::new(mistralrs_quant::Comm::from_device(
394 id,
395 self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
396 0,
397 1,
398 )?))
399 }
400}
401
/// Mapper for NCCL-parallel execution: every layer lives on `nm_device`.
#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    // Reported via `num_device_mapping_layers`.
    model_layers: usize,
    // `None` for the dummy (no-communicator) variant; `get_comm_for` then
    // builds a world-size-1 comm on demand.
    comm: Option<Arc<mistralrs_quant::Comm>>,
}
408
409impl DeviceMapper for NcclDeviceMapper {
410 fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
411 Ok(input)
412 }
413 fn set_device<'a>(
414 &self,
415 _: usize,
416 varbuilder: ShardedVarBuilder,
417 loading_isq: bool,
418 ) -> ShardedVarBuilder {
419 if loading_isq {
420 varbuilder.set_device(Device::Cpu)
421 } else {
422 varbuilder.set_device(self.nm_device.clone())
423 }
424 }
425 fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
426 Some(&self.nm_device)
427 }
428 fn get_unique_devices(&self) -> Vec<Device> {
429 vec![self.nm_device.clone()]
430 }
431 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
432 if loading_isq {
433 x.to_device(&Device::Cpu)
434 } else {
435 x.to_device(&self.nm_device)
436 }
437 }
438 fn set_nm_device<'a>(
439 &self,
440 varbuilder: ShardedVarBuilder,
441 loading_isq: bool,
442 ) -> ShardedVarBuilder {
443 if loading_isq {
444 varbuilder.set_device(Device::Cpu)
445 } else {
446 varbuilder.set_device(self.nm_device.clone())
447 }
448 }
449 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
450 dtype
451 .try_into_dtype(&[&self.nm_device])
452 .map_err(candle_core::Error::msg)
453 }
454 fn num_device_mapping_layers(&self) -> usize {
455 self.model_layers
456 }
457 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
458 if let Some(comm) = &self.comm {
459 Ok(comm.clone())
460 } else {
461 let id = mistralrs_quant::Id::new();
462 Ok(Arc::new(mistralrs_quant::Comm::from_device(
463 id,
464 self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
465 0,
466 1,
467 )?))
468 }
469 }
470}
471
/// Pipeline-parallel NCCL mapper: each layer carries its own
/// `(communicator, device)` pair.
#[derive(Debug)]
#[allow(dead_code)]
pub struct NcclPipelineParallelMapper {
    // `mappings[i]` is the (comm, device) pair for repeating layer `i`.
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    // Device for non-mapped (non-repeating) weights.
    nm_device: Device,
}
479
480impl DeviceMapper for NcclPipelineParallelMapper {
481 fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
482 input.to_device(&self.mappings[layer].1)
483 }
484 fn set_device<'a>(
485 &self,
486 layer: usize,
487 varbuilder: ShardedVarBuilder,
488 loading_isq: bool,
489 ) -> ShardedVarBuilder {
490 if loading_isq {
491 return varbuilder;
492 }
493 varbuilder.set_device(self.mappings[layer].1.clone())
494 }
495 fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
496 if loading_isq {
497 return Some(&self.nm_device);
498 }
499 self.mappings.get(layer).map(|(_, x)| x)
500 }
501 fn get_unique_devices(&self) -> Vec<Device> {
502 self.mappings
503 .iter()
504 .fold(Vec::new(), |mut acc, (_, device)| {
505 if !acc.iter().any(|d| d.same_device(device)) {
506 acc.push(device.clone());
507 }
508 acc
509 })
510 }
511 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
512 if loading_isq {
513 x.to_device(&Device::Cpu)
514 } else {
515 x.to_device(&self.nm_device)
516 }
517 }
518 fn set_nm_device<'a>(
519 &self,
520 varbuilder: ShardedVarBuilder,
521 loading_isq: bool,
522 ) -> ShardedVarBuilder {
523 if loading_isq {
524 varbuilder
525 } else {
526 varbuilder.set_device(self.nm_device.clone())
527 }
528 }
529 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
530 dtype
531 .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
532 .map_err(candle_core::Error::msg)
533 }
534 fn num_device_mapping_layers(&self) -> usize {
535 self.mappings.len()
536 }
537 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
538 Ok(self.mappings[layer_idx].0.clone())
539 }
540}
541
/// Enumerate every available device of the same backend as `base`.
///
/// Ordinals are probed from 0 upward. `base` itself is reused (not re-created)
/// at its own ordinal, preserving any existing device state such as the RNG
/// seed. Probing stops at the first ordinal that fails to open — or, for
/// Metal, once all devices reported by `metal::Device::all()` are visited.
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        // There is only ever one CPU device.
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                // Reuse `base` for its own ordinal instead of reopening it.
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    // The first un-openable ordinal ends the probe.
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            // Metal reports its device count up front.
            let total_ords = metal::Device::all().len();
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                // Reuse `base` for its own ordinal instead of reopening it.
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                // Stop once every reported device has been visited.
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}