mistralrs_core/utils/
memory_usage.rs

1use candle_core::{Device, Result};
2use sysinfo::System;
3
4pub struct MemoryUsage;
5
6impl MemoryUsage {
7    /// Amount of available memory in bytes.
8    pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
9        match device {
10            Device::Cpu => {
11                let mut sys = System::new_all();
12                sys.refresh_cpu();
13                Ok(usize::try_from(sys.available_memory())?)
14            }
15            #[cfg(feature = "cuda")]
16            Device::Cuda(dev) => {
17                use candle_core::cuda::cudarc::driver::result;
18                use candle_core::cuda_backend::WrapErr;
19
20                dev.cuda_stream().context().bind_to_thread().w()?;
21
22                let (free, _total) = result::mem_get_info().w()?;
23
24                Ok(free)
25            }
26            #[cfg(not(feature = "cuda"))]
27            Device::Cuda(_) => {
28                candle_core::bail!("Cannot get memory available for CUDA device")
29            }
30            #[cfg(feature = "metal")]
31            Device::Metal(dev) => {
32                let max = dev.recommended_max_working_set_size();
33                let alloc = dev.current_allocated_size();
34                let avail = max.saturating_sub(alloc);
35
36                #[allow(clippy::cast_possible_truncation)]
37                Ok(avail as usize)
38            }
39            #[cfg(not(feature = "metal"))]
40            Device::Metal(_) => {
41                candle_core::bail!("Cannot get memory available for Metal device")
42            }
43        }
44    }
45
46    /// Amount of total memory in bytes.
47    pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
48        match device {
49            Device::Cpu => {
50                let mut sys = System::new_all();
51                sys.refresh_cpu();
52                Ok(usize::try_from(sys.total_memory())?)
53            }
54            #[cfg(feature = "cuda")]
55            Device::Cuda(dev) => {
56                use candle_core::cuda::cudarc::driver::result;
57                use candle_core::cuda_backend::WrapErr;
58
59                dev.cuda_stream().context().bind_to_thread().w()?;
60
61                let (_free, total) = result::mem_get_info().w()?;
62
63                Ok(total)
64            }
65            #[cfg(not(feature = "cuda"))]
66            Device::Cuda(_) => {
67                candle_core::bail!("Cannot get total memory for CUDA device")
68            }
69            #[cfg(feature = "metal")]
70            #[allow(clippy::cast_possible_truncation)]
71            Device::Metal(dev) => {
72                const SIZE_IN_MB: usize = 1024 * 1024;
73
74                // Get system RAM in MB
75                let system_ram_mb = {
76                    let mut sys = System::new_all();
77                    sys.refresh_cpu();
78                    usize::try_from(sys.total_memory())? / SIZE_IN_MB
79                };
80
81                // Check for Metal GPU wired limit
82                let metal_cap_mb = std::process::Command::new("sysctl")
83                    .arg("-n")
84                    .arg("iogpu.wired_limit_mb")
85                    .output()
86                    .ok()
87                    .and_then(|o| String::from_utf8(o.stdout).ok())
88                    .and_then(|s| s.trim().parse::<usize>().ok());
89
90                // Apply default cap based on system RAM if not set or 0
91                let default_cap = match system_ram_mb {
92                    x if x <= 36 * 1024 => (system_ram_mb * 2) / 3,
93                    x if x > 36 * 1024 => (system_ram_mb * 3) / 4,
94                    x => {
95                        return Err(candle_core::Error::Msg(format!(
96                            "Invalid system ram mb value {x}."
97                        )))
98                    }
99                };
100
101                let metal_cap_mb = match metal_cap_mb {
102                    Some(0) => default_cap,
103                    Some(x) => x,
104                    None => default_cap,
105                };
106
107                let device_max = dev.recommended_max_working_set_size() as usize;
108                let metal_cap_bytes = metal_cap_mb * SIZE_IN_MB;
109
110                Ok(device_max.min(metal_cap_bytes))
111            }
112            #[cfg(not(feature = "metal"))]
113            Device::Metal(_) => {
114                candle_core::bail!("Cannot get memory available for Metal device")
115            }
116        }
117    }
118}