mistralrs_core/utils/memory_usage.rs

1use candle_core::{Device, Result};
2use sysinfo::System;
3
4pub struct MemoryUsage;
5
6impl MemoryUsage {
7    /// Amount of available memory in bytes.
8    #[allow(clippy::cast_possible_truncation)]
9    pub fn get_memory_available(&self, device: &Device) -> Result<usize> {
10        match device {
11            Device::Cpu => {
12                let mut sys = System::new_all();
13                sys.refresh_cpu();
14                Ok(usize::try_from(sys.available_memory())?)
15            }
16            #[cfg(feature = "cuda")]
17            Device::Cuda(dev) => {
18                use candle_core::cuda::cudarc::driver::{result, sys};
19                use candle_core::cuda_backend::WrapErr;
20
21                dev.cuda_stream().context().bind_to_thread().w()?;
22
23                // Check if this is an integrated GPU (unified memory, e.g., NVIDIA GB10)
24                let ordinal = dev.cuda_stream().context().ordinal();
25                let cu_device = result::device::get(ordinal as i32).w()?;
26                let is_integrated = unsafe {
27                    result::device::get_attribute(
28                        cu_device,
29                        sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED,
30                    )
31                    .map(|v| v != 0)
32                    .unwrap_or(false)
33                };
34
35                if is_integrated {
36                    // For integrated GPUs with unified memory, use system memory
37                    // Apply 3/4 fraction to leave room for OS and other processes
38                    let mut sys = System::new_all();
39                    sys.refresh_cpu();
40                    let avail = usize::try_from(sys.available_memory())?;
41                    Ok((avail * 3) / 4)
42                } else {
43                    let (free, _total) = result::mem_get_info().w()?;
44                    Ok(free)
45                }
46            }
47            #[cfg(not(feature = "cuda"))]
48            Device::Cuda(_) => {
49                candle_core::bail!("Cannot get memory available for CUDA device")
50            }
51            #[cfg(feature = "metal")]
52            Device::Metal(dev) => {
53                let max = dev.device().recommended_max_working_set_size();
54                let alloc = dev.current_allocated_size();
55                let avail = max.saturating_sub(alloc);
56
57                #[allow(clippy::cast_possible_truncation)]
58                Ok(avail)
59            }
60            #[cfg(not(feature = "metal"))]
61            Device::Metal(_) => {
62                candle_core::bail!("Cannot get memory available for Metal device")
63            }
64        }
65    }
66
67    /// Amount of total memory in bytes.
68    #[allow(clippy::cast_possible_truncation)]
69    pub fn get_total_memory(&self, device: &Device) -> Result<usize> {
70        match device {
71            Device::Cpu => {
72                let mut sys = System::new_all();
73                sys.refresh_cpu();
74                Ok(usize::try_from(sys.total_memory())?)
75            }
76            #[cfg(feature = "cuda")]
77            Device::Cuda(dev) => {
78                use candle_core::cuda::cudarc::driver::{result, sys};
79                use candle_core::cuda_backend::WrapErr;
80
81                dev.cuda_stream().context().bind_to_thread().w()?;
82
83                // Check if this is an integrated GPU (unified memory, e.g., NVIDIA GB10)
84                let ordinal = dev.cuda_stream().context().ordinal();
85                let cu_device = result::device::get(ordinal as i32).w()?;
86                let is_integrated = unsafe {
87                    result::device::get_attribute(
88                        cu_device,
89                        sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED,
90                    )
91                    .map(|v| v != 0)
92                    .unwrap_or(false)
93                };
94
95                if is_integrated {
96                    // For integrated GPUs with unified memory, use system total memory
97                    // Apply 3/4 fraction similar to Metal's approach
98                    let mut sys = System::new_all();
99                    sys.refresh_cpu();
100                    let total = usize::try_from(sys.total_memory())?;
101                    Ok((total * 3) / 4)
102                } else {
103                    let (_free, total) = result::mem_get_info().w()?;
104                    Ok(total)
105                }
106            }
107            #[cfg(not(feature = "cuda"))]
108            Device::Cuda(_) => {
109                candle_core::bail!("Cannot get total memory for CUDA device")
110            }
111            #[cfg(feature = "metal")]
112            #[allow(clippy::cast_possible_truncation)]
113            Device::Metal(dev) => {
114                const SIZE_IN_MB: usize = 1024 * 1024;
115
116                // Get system RAM in MB
117                let system_ram_mb = {
118                    let mut sys = System::new_all();
119                    sys.refresh_cpu();
120                    usize::try_from(sys.total_memory())? / SIZE_IN_MB
121                };
122
123                // Check for Metal GPU wired limit
124                let metal_cap_mb = std::process::Command::new("sysctl")
125                    .arg("-n")
126                    .arg("iogpu.wired_limit_mb")
127                    .output()
128                    .ok()
129                    .and_then(|o| String::from_utf8(o.stdout).ok())
130                    .and_then(|s| s.trim().parse::<usize>().ok());
131
132                // Apply default cap based on system RAM if not set or 0
133                let default_cap = match system_ram_mb {
134                    x if x <= 36 * 1024 => (system_ram_mb * 2) / 3,
135                    x if x > 36 * 1024 => (system_ram_mb * 3) / 4,
136                    x => {
137                        return Err(candle_core::Error::Msg(format!(
138                            "Invalid system ram mb value {x}."
139                        )))
140                    }
141                };
142
143                let metal_cap_mb = match metal_cap_mb {
144                    Some(0) => default_cap,
145                    Some(x) => x,
146                    None => default_cap,
147                };
148
149                let device_max = dev.recommended_max_working_set_size();
150                let metal_cap_bytes = metal_cap_mb * SIZE_IN_MB;
151
152                Ok(device_max.min(metal_cap_bytes))
153            }
154            #[cfg(not(feature = "metal"))]
155            Device::Metal(_) => {
156                candle_core::bail!("Cannot get memory available for Metal device")
157            }
158        }
159    }
160}