use std::{fmt::Debug, sync::Arc};

use crate::{
    pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
};
use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use tracing::info;

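/// One entry of an explicit device mapping: place `layers` repeating layers on the
/// device with the given `ordinal`.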
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceLayerMapMetadata {
    pub ordinal: usize,
    pub layers: usize,
}

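/// How the repeating layers of a model should be distributed across devices: an
/// explicit `Map`, an `Auto` setting that must be resolved into a map before loading,
/// or the NCCL variants used by the NCCL-parallelized pipeline.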
#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
    Map(DeviceMapMetadata),
    Auto(AutoDeviceMapParams),
    DummyNccl { nm_device: Device },
    Nccl {
        nm_device: Device,
        comm: Arc<mistralrs_quant::Comm>,
    },
}

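/// An explicit per-device layer mapping. Layers not assigned to a device are kept on
/// the host (CPU).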
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceMapMetadata {
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    host_layers: Option<usize>,
}

impl DeviceMapMetadata {
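    /// Build a mapping from explicit per-device layer counts. A minimal sketch (the
    /// ordinals and layer counts below are illustrative, not defaults):
    ///
    /// ```ignore
    /// let map = DeviceMapMetadata::from_num_device_layers(vec![
    ///     DeviceLayerMapMetadata { ordinal: 0, layers: 16 },
    ///     DeviceLayerMapMetadata { ordinal: 1, layers: 16 },
    /// ]);
    /// ```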
    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
        Self {
            device_layers: Some(device_layers),
            host_layers: None,
        }
    }
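    /// An empty mapping: no explicit device layers, so everything stays on the
    /// non-mapped device.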
    pub fn dummy() -> Self {
        Self {
            device_layers: None,
            host_layers: None,
        }
    }
}

impl DeviceMapSetting {
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }
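    /// Convert this setting into a concrete [`DeviceMapper`] for a model with
    /// `model_layers` repeating layers on the primary `device`. A minimal sketch of
    /// the CPU-only case (the layer count of 32 is illustrative):
    ///
    /// ```ignore
    /// let mapper = DeviceMapSetting::dummy().into_mapper(32, &Device::Cpu, None)?;
    /// assert_eq!(mapper.num_device_mapping_layers(), 1);
    /// ```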
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                if let Some(topology) = topology {
                    if topology.0.iter().all(|x| x.is_none()) {
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        let layers = topology
                            .0
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

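                // If no explicit device layers were requested, keep the whole model on
                // the non-mapped device; otherwise cap the requested total at the
                // number of repeating layers.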
                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));

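                // Build the per-layer device list. When more than one device entry is
                // given, each newly created device handle inherits the RNG seed of the
                // primary device.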
                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_cuda_with_stream(*ordinal)?
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

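                // Host (CPU) layers always come after the device layers.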
                combined.extend(vec![Device::Cpu; n_host_layers]);

                assert_eq!(combined.len(), model_layers);

                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i;
                            current_dev = variant;
                        }
                    }

                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }

                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}

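/// The interface used during loading and inference to decide, per repeating layer,
/// which device weights and activations live on.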
pub trait DeviceMapper: Debug {
    /// Move the input tensor to the device assigned to this layer.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    /// Set the varbuilder's device for this layer, taking ISQ loading into account.
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    /// The device assigned to this layer, if any.
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    /// All distinct devices used by this mapping.
    fn get_unique_devices(&self) -> Vec<Device>;
    /// Move the tensor to the non-mapped device, or to the CPU when loading ISQ.
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    /// Set the varbuilder's device to the non-mapped device, taking ISQ loading into account.
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    /// Number of layers covered by this mapping.
    fn num_device_mapping_layers(&self) -> usize;
    /// The collective communicator to use for this layer.
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

    /// Resolve the dtype to use, given the devices in this mapping.
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}

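/// A device mapper that assigns each repeating layer to an explicit device.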
#[derive(Debug)]
pub struct LayerDeviceMapper {
    mappings: Vec<Device>,
    nm_device: Device,
}

impl DeviceMapper for LayerDeviceMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer])
    }
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
            if !acc.iter().any(|d| d.same_device(device)) {
                acc.push(device.clone());
            }
            acc
        })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

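/// A device mapper that keeps every layer on the single non-mapped device.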
#[derive(Debug)]
pub struct DummyDeviceMapper {
    nm_device: Device,
}

impl DeviceMapper for DummyDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        1
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

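/// A device mapper used with the NCCL-parallelized pipeline: every layer stays on this
/// rank's device, and all layers share one communicator when NCCL is in use.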
#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    model_layers: usize,
    comm: Option<Arc<mistralrs_quant::Comm>>,
}

impl DeviceMapper for NcclDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.model_layers
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        if let Some(comm) = &self.comm {
            Ok(comm.clone())
        } else {
            let id = mistralrs_quant::Id::new();
            Ok(Arc::new(mistralrs_quant::Comm::from_device(
                id,
                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
                0,
                1,
            )?))
        }
    }
}

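/// A device mapper for NCCL pipeline parallelism: each layer carries both the device it
/// lives on and the communicator used for its collective operations.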
#[derive(Debug)]
pub struct NcclPipelineParallelMapper {
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    nm_device: Device,
}

impl DeviceMapper for NcclPipelineParallelMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer].1)
    }
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].1.clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer).map(|(_, x)| x)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings
            .iter()
            .fold(Vec::new(), |mut acc, (_, device)| {
                if !acc.iter().any(|d| d.same_device(device)) {
                    acc.push(device.clone());
                }
                acc
            })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        Ok(self.mappings[layer_idx].0.clone())
    }
}

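/// Enumerate every device of the same kind as `base`: just the CPU for a CPU base, or
/// all visible CUDA/Metal ordinals otherwise. A minimal sketch of the CPU case:
///
/// ```ignore
/// let devices = get_all_similar_devices(&Device::Cpu)?;
/// assert_eq!(devices.len(), 1);
/// ```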
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            let total_ords = metal::Device::all().len();
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}
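
// A minimal sketch of the dummy-mapping behavior on CPU; the layer count of 32 is
// illustrative, and the assertions only cover what is visible from this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dummy_mapping_stays_on_one_device() -> Result<()> {
        let mapper = DeviceMapSetting::dummy().into_mapper(32, &Device::Cpu, None)?;
        // The dummy mapper reports a single mapping entry and always answers with the
        // non-mapped device.
        assert_eq!(mapper.num_device_mapping_layers(), 1);
        assert!(mapper
            .device_for(0, false)
            .is_some_and(|d| d.same_device(&Device::Cpu)));
        Ok(())
    }
}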