mistralrs_core/utils/
memory_usage.rs1use candle_core::{Device, Result};
2use sysinfo::System;
3
/// Zero-sized handle for querying available and total memory on a candle `Device`.
///
/// Use the unit value directly, e.g. `MemoryUsage.get_total_memory(&device)`.
#[derive(Debug, Clone, Copy)]
pub struct MemoryUsage;

6impl MemoryUsage {
7 pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
9 match device {
10 Device::Cpu => {
11 let mut sys = System::new_all();
12 sys.refresh_cpu();
13 Ok(usize::try_from(sys.available_memory())?)
14 }
15 #[cfg(feature = "cuda")]
16 Device::Cuda(dev) => {
17 use candle_core::cuda::cudarc;
18 use candle_core::cuda_backend::WrapErr;
19 use candle_core::{backend::BackendDevice, DeviceLocation};
20
21 let DeviceLocation::Cuda { gpu_id } = dev.location() else {
22 candle_core::bail!("device and location do match")
23 };
24
25 let original_ctx = dev.cu_primary_ctx();
26
27 let avail_mem = {
28 #[allow(clippy::cast_possible_truncation)]
29 let cu_device = cudarc::driver::result::device::get(gpu_id as i32).w()?;
30
31 let cu_primary_ctx =
33 unsafe { cudarc::driver::result::primary_ctx::retain(cu_device) }.w()?;
34
35 unsafe { cudarc::driver::result::ctx::set_current(cu_primary_ctx) }.unwrap();
36
37 let res = cudarc::driver::result::mem_get_info().w()?.0;
38
39 unsafe { cudarc::driver::result::primary_ctx::release(cu_device) }.unwrap();
40
41 res
42 };
43
44 unsafe { cudarc::driver::result::ctx::set_current(*original_ctx) }.unwrap();
45
46 Ok(avail_mem)
47 }
48 #[cfg(not(feature = "cuda"))]
49 Device::Cuda(_) => {
50 candle_core::bail!("Cannot get memory available for CUDA device")
51 }
52 #[cfg(feature = "metal")]
53 Device::Metal(dev) => {
54 let max = dev.recommended_max_working_set_size();
55 let alloc = dev.current_allocated_size();
56 let avail = max.saturating_sub(alloc);
57
58 #[allow(clippy::cast_possible_truncation)]
59 Ok(avail as usize)
60 }
61 #[cfg(not(feature = "metal"))]
62 Device::Metal(_) => {
63 candle_core::bail!("Cannot get memory available for Metal device")
64 }
65 }
66 }
67
68 pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
70 match device {
71 Device::Cpu => {
72 let mut sys = System::new_all();
73 sys.refresh_cpu();
74 Ok(usize::try_from(sys.total_memory())?)
75 }
76 #[cfg(feature = "cuda")]
77 Device::Cuda(dev) => {
78 use candle_core::cuda::cudarc;
79 use candle_core::cuda_backend::WrapErr;
80 use candle_core::{backend::BackendDevice, DeviceLocation};
81
82 let DeviceLocation::Cuda { gpu_id } = dev.location() else {
83 candle_core::bail!("device and location do match")
84 };
85
86 let original_ctx = dev.cu_primary_ctx();
87
88 let total_mem = {
89 #[allow(clippy::cast_possible_truncation)]
90 let cu_device = cudarc::driver::result::device::get(gpu_id as i32).w()?;
91
92 let cu_primary_ctx =
94 unsafe { cudarc::driver::result::primary_ctx::retain(cu_device) }.w()?;
95
96 unsafe { cudarc::driver::result::ctx::set_current(cu_primary_ctx) }.unwrap();
97
98 let res = cudarc::driver::result::mem_get_info().w()?.1;
99
100 unsafe { cudarc::driver::result::primary_ctx::release(cu_device) }.unwrap();
101
102 res
103 };
104
105 unsafe { cudarc::driver::result::ctx::set_current(*original_ctx) }.unwrap();
106
107 Ok(total_mem)
108 }
109 #[cfg(not(feature = "cuda"))]
110 Device::Cuda(_) => {
111 candle_core::bail!("Cannot get total memory for CUDA device")
112 }
113 #[cfg(feature = "metal")]
114 #[allow(clippy::cast_possible_truncation)]
115 Device::Metal(dev) => Ok(dev.recommended_max_working_set_size() as usize),
116 #[cfg(not(feature = "metal"))]
117 Device::Metal(_) => {
118 candle_core::bail!("Cannot get memory available for Metal device")
119 }
120 }
121 }
122}