mistralrs_core/utils/
memory_usage.rs1use candle_core::{Device, Result};
2use sysinfo::System;
3
/// Stateless helper for querying available and total memory on a candle [`Device`].
///
/// It is a unit struct; all work happens in its methods, so it is free to copy
/// and construct.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryUsage;
5
6impl MemoryUsage {
7 #[allow(clippy::cast_possible_truncation)]
9 pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
10 match device {
11 Device::Cpu => {
12 let mut sys = System::new_all();
13 sys.refresh_cpu();
14 Ok(usize::try_from(sys.available_memory())?)
15 }
16 #[cfg(feature = "cuda")]
17 Device::Cuda(dev) => {
18 use candle_core::cuda::cudarc::driver::{result, sys};
19 use candle_core::cuda_backend::WrapErr;
20
21 dev.cuda_stream().context().bind_to_thread().w()?;
22
23 let ordinal = dev.cuda_stream().context().ordinal();
25 let cu_device = result::device::get(ordinal as i32).w()?;
26 let is_integrated = unsafe {
27 result::device::get_attribute(
28 cu_device,
29 sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED,
30 )
31 .map(|v| v != 0)
32 .unwrap_or(false)
33 };
34
35 if is_integrated {
36 let mut sys = System::new_all();
39 sys.refresh_cpu();
40 let avail = usize::try_from(sys.available_memory())?;
41 Ok((avail * 3) / 4)
42 } else {
43 let (free, _total) = result::mem_get_info().w()?;
44 Ok(free)
45 }
46 }
47 #[cfg(not(feature = "cuda"))]
48 Device::Cuda(_) => {
49 candle_core::bail!("Cannot get memory available for CUDA device")
50 }
51 #[cfg(feature = "metal")]
52 Device::Metal(dev) => {
53 let max = dev.device().recommended_max_working_set_size();
54 let alloc = dev.current_allocated_size();
55 let avail = max.saturating_sub(alloc);
56
57 #[allow(clippy::cast_possible_truncation)]
58 Ok(avail)
59 }
60 #[cfg(not(feature = "metal"))]
61 Device::Metal(_) => {
62 candle_core::bail!("Cannot get memory available for Metal device")
63 }
64 }
65 }
66
67 #[allow(clippy::cast_possible_truncation)]
69 pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
70 match device {
71 Device::Cpu => {
72 let mut sys = System::new_all();
73 sys.refresh_cpu();
74 Ok(usize::try_from(sys.total_memory())?)
75 }
76 #[cfg(feature = "cuda")]
77 Device::Cuda(dev) => {
78 use candle_core::cuda::cudarc::driver::{result, sys};
79 use candle_core::cuda_backend::WrapErr;
80
81 dev.cuda_stream().context().bind_to_thread().w()?;
82
83 let ordinal = dev.cuda_stream().context().ordinal();
85 let cu_device = result::device::get(ordinal as i32).w()?;
86 let is_integrated = unsafe {
87 result::device::get_attribute(
88 cu_device,
89 sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED,
90 )
91 .map(|v| v != 0)
92 .unwrap_or(false)
93 };
94
95 if is_integrated {
96 let mut sys = System::new_all();
99 sys.refresh_cpu();
100 let total = usize::try_from(sys.total_memory())?;
101 Ok((total * 3) / 4)
102 } else {
103 let (_free, total) = result::mem_get_info().w()?;
104 Ok(total)
105 }
106 }
107 #[cfg(not(feature = "cuda"))]
108 Device::Cuda(_) => {
109 candle_core::bail!("Cannot get total memory for CUDA device")
110 }
111 #[cfg(feature = "metal")]
112 #[allow(clippy::cast_possible_truncation)]
113 Device::Metal(dev) => {
114 const SIZE_IN_MB: usize = 1024 * 1024;
115
116 let system_ram_mb = {
118 let mut sys = System::new_all();
119 sys.refresh_cpu();
120 usize::try_from(sys.total_memory())? / SIZE_IN_MB
121 };
122
123 let metal_cap_mb = std::process::Command::new("sysctl")
125 .arg("-n")
126 .arg("iogpu.wired_limit_mb")
127 .output()
128 .ok()
129 .and_then(|o| String::from_utf8(o.stdout).ok())
130 .and_then(|s| s.trim().parse::<usize>().ok());
131
132 let default_cap = match system_ram_mb {
134 x if x <= 36 * 1024 => (system_ram_mb * 2) / 3,
135 x if x > 36 * 1024 => (system_ram_mb * 3) / 4,
136 x => {
137 return Err(candle_core::Error::Msg(format!(
138 "Invalid system ram mb value {x}."
139 )))
140 }
141 };
142
143 let metal_cap_mb = match metal_cap_mb {
144 Some(0) => default_cap,
145 Some(x) => x,
146 None => default_cap,
147 };
148
149 let device_max = dev.recommended_max_working_set_size();
150 let metal_cap_bytes = metal_cap_mb * SIZE_IN_MB;
151
152 Ok(device_max.min(metal_cap_bytes))
153 }
154 #[cfg(not(feature = "metal"))]
155 Device::Metal(_) => {
156 candle_core::bail!("Cannot get memory available for Metal device")
157 }
158 }
159 }
160}