mistralrs_core/paged_attention/mod.rs

mod block_engine;
mod block_engine_sequence;
mod cache_engine;
mod config;
mod layers;
mod scheduler;
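
/// Sentinel slot id used for padding; it never refers to a real KV-cache slot.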
pub const _PAD_SLOT_ID: i64 = -1;

pub use block_engine::{BlockEngine, BlockTables, LogicalTokenBlock, PhysicalTokenBlock};
pub use block_engine_sequence::BlockEngineSequence;
pub use cache_engine::{CacheConfig, CacheEngine, PagedCacheType};
use candle_core::{DType, Device};
pub use config::{ModelConfigLike, ModelConfigMetadata};
pub use layers::PagedAttention;
pub use scheduler::{
    PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput,
};

use crate::MemoryUsage;
use tracing::{info, warn};
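
/// Default number of tokens stored per KV-cache block.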
pub const DEFAULT_PAGED_ATTENTION_BLOCK_SIZE: usize = 32;
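
/// Configuration for PagedAttention: an optional block size (defaults to
/// [`DEFAULT_PAGED_ATTENTION_BLOCK_SIZE`]), the GPU memory sizing strategy, and the
/// KV-cache storage type.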
#[derive(Clone, Copy)]
pub struct PagedAttentionConfig {
    pub(crate) block_size: Option<usize>,
    pub(crate) mem_gpu: MemoryGpuConfig,
    pub(crate) cache_type: PagedCacheType,
}

impl PagedAttentionConfig {
    pub fn new(
        block_size: Option<usize>,
        mem_gpu: MemoryGpuConfig,
        cache_type: PagedCacheType,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            block_size,
            mem_gpu,
            cache_type,
        })
    }
}
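
/// Selects the attention backend: standard (eager) attention or the PagedAttention kernels.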
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AttentionImplementation {
    Eager,
    PagedAttention,
}
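
/// How to size the GPU KV cache: an absolute amount in MB, a target fraction of total
/// GPU memory, or the amount needed to hold a given context length in tokens.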
#[derive(Clone, Copy)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
pub enum MemoryGpuConfig {
    MbAmount(usize),
    Utilization(f32),
    ContextSize(usize),
}
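
// A minimal construction sketch (hypothetical call site; `PagedCacheType::Auto` is assumed
// to be the pass-through variant of the cache type re-exported above):
//
//     let cfg = PagedAttentionConfig::new(
//         None,                              // default block size (32)
//         MemoryGpuConfig::Utilization(0.9), // target 90% of total GPU memory
//         PagedCacheType::Auto,
//     )?;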
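
// Block sizes accepted by `calculate_cache_config`.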
const SUPPORTED_BLOCK_SIZE: &[usize] = &[8, 16, 32];

// Bytes per MiB (the "MB" naming follows the rest of this module).
const SIZE_IN_MB: usize = 1024 * 1024;
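
// Convert a memory budget in bytes into a whole number of KV-cache blocks. One block holds
// `block_size` tokens of both K and V (hence the trailing `/ 2`) for every KV head in every
// layer. Worked example on a hypothetical model (32 layers, 8 KV heads, head_dim 128,
// f16 => 2-byte elements, block_size 32): a block occupies
// 2 * 32 * 8 * 128 * 32 * 2 = 4_194_304 bytes = 4 MiB, so a 4096 MB budget yields 1024 blocks.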
macro_rules! mb_to_blocks {
    ($mb_size:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $mb_size
            / $dtype_size
            / $block_size
            / $config.num_kv_heads()
            / ($config.k_head_dim().max($config.v_head_dim()))
            / $config.num_layers()
            / 2
    };
}
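
// Counterpart used for `MemoryGpuConfig::ContextSize`: despite the parallel name, this
// expands to the KV-cache footprint in *bytes* for `$context_len` tokens (the
// `$block_size` argument is unused); the caller divides by `SIZE_IN_MB` to get MB.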
macro_rules! ctxt_to_blocks {
    ($context_len:expr, $dtype_size:expr, $block_size:expr, $config:expr) => {
        $context_len
            * $dtype_size
            * $config.num_kv_heads()
            * ($config.k_head_dim().max($config.v_head_dim()))
            * $config.num_layers()
            * 2
    };
}
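
/// Compute the PagedAttention [`CacheConfig`] for a model/device combination.
///
/// The available memory is evaluated per layer device and the most constrained value is
/// used; on Metal the budget is additionally capped by the wired limit. The resulting MB
/// budget is then converted into a whole number of KV-cache blocks.
///
/// A usage sketch (every identifier not defined in this module is a hypothetical
/// placeholder):
///
/// ```ignore
/// let cache_config = calculate_cache_config(
///     MemoryGpuConfig::Utilization(0.9), // use up to 90% of total GPU memory
///     None,                              // default block size (32)
///     DType::BF16,
///     cache_type,
///     &*model_config,
///     &device,
///     &layer_devices,
///     false, // not silent: log the chosen configuration
/// )?;
/// ```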
#[allow(clippy::too_many_arguments)]
pub fn calculate_cache_config(
    mem_gpu: MemoryGpuConfig,
    block_size: Option<usize>,
    dtype: DType,
    cache_type: PagedCacheType,
    config: &dyn ModelConfigLike,
    device: &Device,
    layer_devices: &[Option<Device>],
    silent: bool,
) -> anyhow::Result<CacheConfig> {
    let block_size = block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE);
    if !SUPPORTED_BLOCK_SIZE.contains(&block_size) {
        anyhow::bail!("Block size must be in {SUPPORTED_BLOCK_SIZE:?}, got {block_size}");
    }
    let dtype = cache_type.to_dtype(dtype);
    let dtype_size = dtype.size_in_bytes();

    // Evaluate the budget per layer device and keep the most constrained value.
    let mut min_mem_gpu = usize::MAX;
    for dev in layer_devices {
        let device = dev.as_ref().unwrap_or(device);

        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        let mem_gpu = match mem_gpu {
            MemoryGpuConfig::MbAmount(v) => v,
            MemoryGpuConfig::Utilization(f) => {
                // Aim for the fraction `f` of *total* memory, less what is already in use.
                let free = MemoryUsage.get_memory_available(device)? as f32 / SIZE_IN_MB as f32;
                let total = MemoryUsage.get_total_memory(device)? as f32 / SIZE_IN_MB as f32;
                let used = total - free;
                (total * f - used) as usize
            }
            MemoryGpuConfig::ContextSize(toks) => {
                ctxt_to_blocks!(toks, dtype_size, block_size, config) / SIZE_IN_MB
            }
        };
        min_mem_gpu = min_mem_gpu.min(mem_gpu);
    }

    // On Metal, allocations are bounded by the wired limit, so cap the budget there.
    let mem_gpu = if matches!(device, Device::Metal(_)) {
        let metal_cap_mb = MemoryUsage.get_total_memory(device)? / SIZE_IN_MB;

        info!("Metal GPU wired limit is {metal_cap_mb} MB.");

        if min_mem_gpu > metal_cap_mb {
            if !silent {
                warn!(
                    "Capping Metal GPU memory allocation from {} MB to {} MB (limited by iogpu.wired_limit_mb). \
To raise this cap run: `sudo sysctl -w iogpu.wired_limit_mb=<desired_mb>`.",
                    min_mem_gpu, metal_cap_mb
                );
            }
            metal_cap_mb
        } else {
            min_mem_gpu
        }
    } else {
        min_mem_gpu
    };

    let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config);
    if num_gpu_blocks == 0 {
        anyhow::bail!("Num GPU blocks is 0. This means there is not enough memory. Either reduce the memory amount/utilization/context size or disable PagedAttention.");
    }

    if !silent {
        info!("Allocating {mem_gpu} MB for PagedAttention KV cache per GPU");
        info!("PagedAttention KV cache type is {dtype:?}");
        info!("Using PagedAttention with block size {block_size} and {num_gpu_blocks} GPU blocks: available context length is {} tokens", num_gpu_blocks * block_size);
    }
    Ok(CacheConfig {
        block_size,
        num_gpu_blocks,
        cache_type,
    })
}
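
#[cfg(test)]
mod tests {
    // Self-contained sketch (not part of the original file) checking the block arithmetic
    // documented above. `FakeCfg` is a hypothetical stand-in: the macros are duck-typed,
    // so any type exposing these four methods works.
    struct FakeCfg;

    impl FakeCfg {
        fn num_kv_heads(&self) -> usize {
            8
        }
        fn k_head_dim(&self) -> usize {
            128
        }
        fn v_head_dim(&self) -> usize {
            128
        }
        fn num_layers(&self) -> usize {
            32
        }
    }

    #[test]
    fn mb_to_blocks_matches_hand_computation() {
        let cfg = FakeCfg;
        // 4096 MB budget, f16 (2-byte elements), block_size 32 => 4 MiB per block => 1024 blocks.
        let blocks = mb_to_blocks!(4096 * super::SIZE_IN_MB, 2, 32, cfg);
        assert_eq!(blocks, 1024);
    }
}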