mistralrs_core/utils/
memory_usage.rs

1use candle_core::{Device, Result};
2use sysinfo::System;
3
4pub struct MemoryUsage;
5
6impl MemoryUsage {
7    /// Amount of available memory in bytes.
8    pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
9        match device {
10            Device::Cpu => {
11                let mut sys = System::new_all();
12                sys.refresh_cpu();
13                Ok(usize::try_from(sys.available_memory())?)
14            }
15            #[cfg(feature = "cuda")]
16            Device::Cuda(dev) => {
17                use candle_core::cuda::cudarc;
18                use candle_core::cuda_backend::WrapErr;
19                use candle_core::{backend::BackendDevice, DeviceLocation};
20
21                let DeviceLocation::Cuda { gpu_id } = dev.location() else {
22                    candle_core::bail!("device and location do match")
23                };
24
25                let original_ctx = dev.cu_primary_ctx();
26
27                let avail_mem = {
28                    #[allow(clippy::cast_possible_truncation)]
29                    let cu_device = cudarc::driver::result::device::get(gpu_id as i32).w()?;
30
31                    // primary context initialization, can fail with OOM
32                    let cu_primary_ctx =
33                        unsafe { cudarc::driver::result::primary_ctx::retain(cu_device) }.w()?;
34
35                    unsafe { cudarc::driver::result::ctx::set_current(cu_primary_ctx) }.unwrap();
36
37                    let res = cudarc::driver::result::mem_get_info().w()?.0;
38
39                    unsafe { cudarc::driver::result::primary_ctx::release(cu_device) }.unwrap();
40
41                    res
42                };
43
44                unsafe { cudarc::driver::result::ctx::set_current(*original_ctx) }.unwrap();
45
46                Ok(avail_mem)
47            }
48            #[cfg(not(feature = "cuda"))]
49            Device::Cuda(_) => {
50                candle_core::bail!("Cannot get memory available for CUDA device")
51            }
52            #[cfg(feature = "metal")]
53            Device::Metal(dev) => {
54                let max = dev.recommended_max_working_set_size();
55                let alloc = dev.current_allocated_size();
56                let avail = max.saturating_sub(alloc);
57
58                #[allow(clippy::cast_possible_truncation)]
59                Ok(avail as usize)
60            }
61            #[cfg(not(feature = "metal"))]
62            Device::Metal(_) => {
63                candle_core::bail!("Cannot get memory available for Metal device")
64            }
65        }
66    }
67
68    /// Amount of total memory in bytes.
69    pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
70        match device {
71            Device::Cpu => {
72                let mut sys = System::new_all();
73                sys.refresh_cpu();
74                Ok(usize::try_from(sys.total_memory())?)
75            }
76            #[cfg(feature = "cuda")]
77            Device::Cuda(dev) => {
78                use candle_core::cuda::cudarc;
79                use candle_core::cuda_backend::WrapErr;
80                use candle_core::{backend::BackendDevice, DeviceLocation};
81
82                let DeviceLocation::Cuda { gpu_id } = dev.location() else {
83                    candle_core::bail!("device and location do match")
84                };
85
86                let original_ctx = dev.cu_primary_ctx();
87
88                let total_mem = {
89                    #[allow(clippy::cast_possible_truncation)]
90                    let cu_device = cudarc::driver::result::device::get(gpu_id as i32).w()?;
91
92                    // primary context initialization, can fail with OOM
93                    let cu_primary_ctx =
94                        unsafe { cudarc::driver::result::primary_ctx::retain(cu_device) }.w()?;
95
96                    unsafe { cudarc::driver::result::ctx::set_current(cu_primary_ctx) }.unwrap();
97
98                    let res = cudarc::driver::result::mem_get_info().w()?.1;
99
100                    unsafe { cudarc::driver::result::primary_ctx::release(cu_device) }.unwrap();
101
102                    res
103                };
104
105                unsafe { cudarc::driver::result::ctx::set_current(*original_ctx) }.unwrap();
106
107                Ok(total_mem)
108            }
109            #[cfg(not(feature = "cuda"))]
110            Device::Cuda(_) => {
111                candle_core::bail!("Cannot get total memory for CUDA device")
112            }
113            #[cfg(feature = "metal")]
114            #[allow(clippy::cast_possible_truncation)]
115            Device::Metal(dev) => {
116                const SIZE_IN_MB: usize = 1024 * 1024;
117
118                // Get system RAM in MB
119                let system_ram_mb = {
120                    let mut sys = System::new_all();
121                    sys.refresh_cpu();
122                    usize::try_from(sys.total_memory())? / SIZE_IN_MB
123                };
124
125                // Check for Metal GPU wired limit
126                let metal_cap_mb = std::process::Command::new("sysctl")
127                    .arg("-n")
128                    .arg("iogpu.wired_limit_mb")
129                    .output()
130                    .ok()
131                    .and_then(|o| String::from_utf8(o.stdout).ok())
132                    .and_then(|s| s.trim().parse::<usize>().ok());
133
134                // Apply default cap based on system RAM if not set or 0
135                let default_cap = match system_ram_mb {
136                    x if x <= 36 * 1024 => (system_ram_mb * 2) / 3,
137                    x if x > 36 * 1024 => (system_ram_mb * 3) / 4,
138                    x => {
139                        return Err(candle_core::Error::Msg(format!(
140                            "Invalid system ram mb value {x}."
141                        )))
142                    }
143                };
144
145                let metal_cap_mb = match metal_cap_mb {
146                    Some(x) if x == 0 => default_cap,
147                    Some(x) => x,
148                    None => default_cap,
149                };
150
151                let device_max = dev.recommended_max_working_set_size() as usize;
152                let metal_cap_bytes = metal_cap_mb * SIZE_IN_MB;
153
154                Ok(device_max.min(metal_cap_bytes))
155            }
156            #[cfg(not(feature = "metal"))]
157            Device::Metal(_) => {
158                candle_core::bail!("Cannot get memory available for Metal device")
159            }
160        }
161    }
162}