mistralrs_core/utils/
memory_usage.rs

1use candle_core::{Device, Result};
2use sysinfo::System;
3
/// Zero-sized handle for querying how much memory a compute device has.
///
/// The type carries no state; its methods take the target device as an
/// argument, so a value can be created freely (`MemoryUsage` or
/// `MemoryUsage::default()`) and copied around at no cost.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryUsage;
5
6impl MemoryUsage {
7    /// Amount of available memory in bytes.
8    pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
9        match device {
10            Device::Cpu => {
11                let mut sys = System::new_all();
12                sys.refresh_cpu();
13                Ok(usize::try_from(sys.available_memory())?)
14            }
15            #[cfg(feature = "cuda")]
16            Device::Cuda(dev) => {
17                use candle_core::cuda::cudarc;
18                use candle_core::cuda_backend::WrapErr;
19                use candle_core::{backend::BackendDevice, DeviceLocation};
20
21                let DeviceLocation::Cuda { gpu_id } = dev.location() else {
22                    candle_core::bail!("device and location do match")
23                };
24
25                let original_ctx = dev.cu_primary_ctx();
26
27                let avail_mem = {
28                    #[allow(clippy::cast_possible_truncation)]
29                    let cu_device = cudarc::driver::result::device::get(gpu_id as i32).w()?;
30
31                    // primary context initialization, can fail with OOM
32                    let cu_primary_ctx =
33                        unsafe { cudarc::driver::result::primary_ctx::retain(cu_device) }.w()?;
34
35                    unsafe { cudarc::driver::result::ctx::set_current(cu_primary_ctx) }.unwrap();
36
37                    let res = cudarc::driver::result::mem_get_info().w()?.0;
38
39                    unsafe { cudarc::driver::result::primary_ctx::release(cu_device) }.unwrap();
40
41                    res
42                };
43
44                unsafe { cudarc::driver::result::ctx::set_current(*original_ctx) }.unwrap();
45
46                Ok(avail_mem)
47            }
48            #[cfg(not(feature = "cuda"))]
49            Device::Cuda(_) => {
50                candle_core::bail!("Cannot get memory available for CUDA device")
51            }
52            #[cfg(feature = "metal")]
53            Device::Metal(dev) => {
54                let max = dev.recommended_max_working_set_size();
55                let alloc = dev.current_allocated_size();
56                let avail = max.saturating_sub(alloc);
57
58                #[allow(clippy::cast_possible_truncation)]
59                Ok(avail as usize)
60            }
61            #[cfg(not(feature = "metal"))]
62            Device::Metal(_) => {
63                candle_core::bail!("Cannot get memory available for Metal device")
64            }
65        }
66    }
67
68    /// Amount of total memory in bytes.
69    pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
70        match device {
71            Device::Cpu => {
72                let mut sys = System::new_all();
73                sys.refresh_cpu();
74                Ok(usize::try_from(sys.total_memory())?)
75            }
76            #[cfg(feature = "cuda")]
77            Device::Cuda(dev) => {
78                use candle_core::cuda::cudarc;
79                use candle_core::cuda_backend::WrapErr;
80                use candle_core::{backend::BackendDevice, DeviceLocation};
81
82                let DeviceLocation::Cuda { gpu_id } = dev.location() else {
83                    candle_core::bail!("device and location do match")
84                };
85
86                let original_ctx = dev.cu_primary_ctx();
87
88                let total_mem = {
89                    #[allow(clippy::cast_possible_truncation)]
90                    let cu_device = cudarc::driver::result::device::get(gpu_id as i32).w()?;
91
92                    // primary context initialization, can fail with OOM
93                    let cu_primary_ctx =
94                        unsafe { cudarc::driver::result::primary_ctx::retain(cu_device) }.w()?;
95
96                    unsafe { cudarc::driver::result::ctx::set_current(cu_primary_ctx) }.unwrap();
97
98                    let res = cudarc::driver::result::mem_get_info().w()?.1;
99
100                    unsafe { cudarc::driver::result::primary_ctx::release(cu_device) }.unwrap();
101
102                    res
103                };
104
105                unsafe { cudarc::driver::result::ctx::set_current(*original_ctx) }.unwrap();
106
107                Ok(total_mem)
108            }
109            #[cfg(not(feature = "cuda"))]
110            Device::Cuda(_) => {
111                candle_core::bail!("Cannot get total memory for CUDA device")
112            }
113            #[cfg(feature = "metal")]
114            #[allow(clippy::cast_possible_truncation)]
115            Device::Metal(dev) => Ok(dev.recommended_max_working_set_size() as usize),
116            #[cfg(not(feature = "metal"))]
117            Device::Metal(_) => {
118                candle_core::bail!("Cannot get memory available for Metal device")
119            }
120        }
121    }
122}