mistralrs_core/dummy_paged_attention/block_engine.rs

use std::{
collections::{hash_map::Entry, HashMap},
hash::Hash,
marker::PhantomData,
ops::Deref,
sync::{Arc, Mutex, MutexGuard},
};
use super::block_engine_sequence::BlockEngineSequence;
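/// A logical block of token ids for one sequence. The backing vector always has
/// `block_size` slots; `num_tokens` tracks how many of them are filled.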
pub struct LogicalTokenBlock {
tokens: Vec<usize>,
block_size: usize,
num_tokens: usize,
}
impl LogicalTokenBlock {
pub fn new(block_size: usize) -> Self {
Self {
tokens: vec![0; block_size],
block_size,
num_tokens: 0,
}
}
pub fn is_full(&self) -> bool {
self.num_tokens == self.block_size
}
pub fn is_empty(&self) -> bool {
self.num_tokens == 0
}
pub fn append_token_id(&mut self, token: usize) {
assert!(!self.is_full());
self.tokens[self.num_tokens] = token;
self.num_tokens += 1;
}
pub fn pop_token(&mut self) {
assert_ne!(self.num_tokens, 0);
// Only the logical length is decremented; the backing vector keeps its fixed
// `block_size` length, since `append_token_id` writes by index.
self.num_tokens -= 1;
}
}
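/// Inner state of a physical block: its id, block size, reference count, and
/// whether it resides on the GPU.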
#[derive(Hash, PartialEq, Eq)]
pub struct _PhysicalTokenBlock {
pub block_id: usize,
block_size: usize,
refcount: usize,
is_gpu: bool,
}
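/// A physical KV-cache block, wrapped in a `Mutex` so it can be shared across
/// block tables via `Arc`.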
pub struct PhysicalTokenBlock(pub Mutex<_PhysicalTokenBlock>);
impl PhysicalTokenBlock {
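/// Busy-wait on `try_lock` until the inner mutex is acquired, then return the guard.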
pub fn deref_mut(&self) -> MutexGuard<'_, _PhysicalTokenBlock> {
loop {
if let Ok(v) = self.0.try_lock() {
return v;
}
}
}
}
impl PartialEq for PhysicalTokenBlock {
fn eq(&self, other: &Self) -> bool {
*self.deref_mut() == *other.deref_mut()
}
}
impl Hash for PhysicalTokenBlock {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.deref_mut().hash(state)
}
}
impl Eq for PhysicalTokenBlock {}
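/// The ordered list of physical blocks backing one sequence's logical blocks.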
type BlockTable = Vec<Arc<PhysicalTokenBlock>>;
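// Zero-sized marker types selecting which device an `Allocator` manages.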
struct GPUAllocator;
struct CPUAllocator;
struct GPUAllocatorWrapper(usize);
impl Deref for GPUAllocatorWrapper {
type Target = usize;
fn deref(&self) -> &Self::Target {
&self.0
}
}
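/// A free-list allocator over the physical blocks of a single device.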
struct Allocator<T> {
free_blocks: BlockTable,
_ghost: PhantomData<T>,
}
impl<T> Allocator<T> {
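/// Pop a block off the free list and hand it out with a refcount of 1.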
fn allocate(&mut self) -> Arc<PhysicalTokenBlock> {
let block = self.free_blocks.pop().unwrap();
block.deref_mut().refcount = 1;
block
}
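/// Drop one reference to `block`, returning it to the free list once its refcount
/// reaches zero. Panics if the block was already free.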
fn free_block(&mut self, block: Arc<PhysicalTokenBlock>) {
if block.deref_mut().refcount == 0 {
panic!(
"PhysicalTokenBlock with id {} experienced a double free!",
block.deref_mut().block_id
);
}
block.deref_mut().refcount -= 1;
if block.deref_mut().refcount == 0 {
self.free_blocks.push(block);
}
}
}
impl Allocator<GPUAllocator> {
fn new(block_size: usize, num_blocks: usize) -> Self {
let mut free_blocks = Vec::new();
for id in 0..num_blocks {
free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
_PhysicalTokenBlock {
block_id: id,
block_size,
refcount: 0,
is_gpu: true,
},
))))
}
Allocator {
free_blocks,
_ghost: PhantomData,
}
}
fn get_num_free_blocks(&self) -> GPUAllocatorWrapper {
GPUAllocatorWrapper(self.free_blocks.len())
}
}
impl Allocator<CPUAllocator> {
fn new(block_size: usize, num_blocks: usize) -> Self {
let mut free_blocks = Vec::new();
for id in 0..num_blocks {
free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new(
_PhysicalTokenBlock {
block_id: id,
block_size,
refcount: 0,
is_gpu: false,
},
))))
}
Allocator {
free_blocks,
_ghost: PhantomData,
}
}
}
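/// Result of checking whether a sequence's blocks can be allocated on the GPU.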
pub enum AllocStatus {
Ok,
Later,
Impossible,
}
type SeqID = usize;
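/// Tracks the mapping from sequences to physical KV-cache blocks, with separate
/// free-block allocators for the GPU and the CPU (the latter used for swapping).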
pub struct BlockEngine {
num_gpu_blocks: usize,
gpu_allocator: Allocator<GPUAllocator>,
cpu_allocator: Allocator<CPUAllocator>,
pub block_tables: HashMap<SeqID, BlockTable>,
}
pub type BlockTables = HashMap<usize, BlockTable>;
impl BlockEngine {
#[must_use]
pub fn new(block_size: usize, num_gpu_blocks: usize, num_cpu_blocks: usize) -> Self {
Self {
num_gpu_blocks,
gpu_allocator: Allocator::<GPUAllocator>::new(block_size, num_gpu_blocks),
cpu_allocator: Allocator::<CPUAllocator>::new(block_size, num_cpu_blocks),
block_tables: HashMap::new(),
}
}
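/// Check whether `seq`'s logical blocks can be allocated on the GPU now (`Ok`),
/// possibly later (`Later`), or never, because the sequence needs more blocks
/// than the GPU has in total (`Impossible`).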
pub fn can_allocate(&self, seq: &impl BlockEngineSequence) -> AllocStatus {
let num_required_blocks = seq.get_logical_token_blocks();
let num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks();
if self.num_gpu_blocks > *num_free_gpu_blocks + num_required_blocks {
AllocStatus::Later
} else if self.num_gpu_blocks < num_required_blocks {
AllocStatus::Impossible
} else {
AllocStatus::Ok
}
}
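/// Allocate one GPU block per logical block of `seq` and record its block table.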
pub fn allocate(&mut self, seq: &impl BlockEngineSequence) {
let mut block_table = Vec::new();
for _logical_idx in 0..seq.get_logical_token_blocks() {
block_table.push(self.gpu_allocator.allocate());
}
self.block_tables.insert(seq.get_id(), block_table);
}
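/// Whether the GPU has enough free blocks for the block(s) a newly appended token would need.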
pub fn can_append_token_to_seq(&self, seq: &impl BlockEngineSequence) -> bool {
let free_blocks = self.gpu_allocator.get_num_free_blocks();
seq.blocks_to_add_new_tok() <= *free_blocks
}
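/// Return every block in the sequence's block table to its device allocator and
/// remove the table.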
pub fn free_sequence(&mut self, id: usize) {
if let Some(block_table) = self.block_tables.get(&id) {
for block in block_table {
if block.deref_mut().is_gpu {
self.gpu_allocator.free_block(block.clone())
} else {
self.cpu_allocator.free_block(block.clone())
}
}
self.block_tables.remove(&id);
}
}
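/// Whether the CPU allocator has enough free blocks to receive all of this
/// sequence's blocks.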
#[allow(dead_code)]
pub fn can_swap_out_seq(&self, seq: &impl BlockEngineSequence) -> bool {
let blocks_required: usize = self
.block_tables
.iter()
.filter(|(id, _)| seq.get_id() == **id)
.map(|(_, table)| table.len())
.sum();
blocks_required <= self.cpu_allocator.free_blocks.len()
}
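/// Move all of the sequence's blocks from the GPU to the CPU, returning a map
/// from old GPU block ids to new CPU block ids.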
#[allow(dead_code)]
pub fn swap_out(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
let mut new_mapping = HashMap::new();
let seq_id = seq.get_id();
let mut new_block_table = Vec::new();
let block_table = self.block_tables.get(&seq_id).unwrap();
for gpu_block in block_table {
let cpu_block =
if let Entry::Vacant(e) = new_mapping.entry(gpu_block.deref_mut().block_id) {
let cpu_block = self.cpu_allocator.allocate();
e.insert(cpu_block.clone());
cpu_block
} else {
let cpu_block = new_mapping
.get(&gpu_block.deref_mut().block_id)
.unwrap()
.clone();
cpu_block.deref_mut().refcount += 1;
cpu_block
};
new_block_table.push(cpu_block);
self.gpu_allocator.free_block(gpu_block.clone());
}
self.block_tables.insert(seq_id, new_block_table);
new_mapping
.iter()
.map(|(k, v)| (*k, v.deref_mut().block_id))
.collect::<HashMap<_, _>>()
}
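/// Reserve space for one new token. Allocates a fresh GPU block when a new block
/// is needed; if the last block is shared (refcount > 1), performs copy-on-write
/// and returns the `(old, new)` block ids.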
pub fn append_token_slot_to_seq(
&mut self,
sequence: &impl BlockEngineSequence,
) -> Option<(usize, usize)> {
let table = self.block_tables.get_mut(&sequence.get_id())?;
match sequence.blocks_to_add_new_tok() {
1 => {
table.push(self.gpu_allocator.allocate());
None
}
0 => {
let last_block = table.last_mut().unwrap();
assert!(last_block.deref_mut().is_gpu);
if last_block.deref_mut().refcount == 1 {
None
} else {
let new_block = self.gpu_allocator.allocate();
self.gpu_allocator.free_block(last_block.clone());
let old_number = last_block.deref_mut().block_id;
let new_number = new_block.deref_mut().block_id;
*last_block = new_block;
Some((old_number, new_number))
}
}
_ => {
unreachable!()
}
}
}
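/// Whether the GPU allocator has enough free blocks to receive all of this
/// sequence's blocks.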
pub fn can_swap_in_seq(&self, seq: &impl BlockEngineSequence) -> bool {
let blocks_required: usize = self
.block_tables
.iter()
.filter(|(id, _)| seq.get_id() == **id)
.map(|(_, table)| table.len())
.sum();
blocks_required <= self.gpu_allocator.free_blocks.len()
}
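/// Move all of the sequence's blocks from the CPU back to the GPU, returning a
/// map from old CPU block ids to new GPU block ids.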
pub fn swap_in(&mut self, seq: &impl BlockEngineSequence) -> HashMap<usize, usize> {
let mut new_mapping = HashMap::new();
let seq_id = seq.get_id();
let mut new_block_table = Vec::new();
let block_table = self.block_tables.get(&seq_id).unwrap();
for cpu_block in block_table {
let gpu_block =
if let Entry::Vacant(e) = new_mapping.entry(cpu_block.deref_mut().block_id) {
// Allocate the destination block on the GPU.
let gpu_block = self.gpu_allocator.allocate();
e.insert(gpu_block.clone());
gpu_block
} else {
let gpu_block = new_mapping
.get(&cpu_block.deref_mut().block_id)
.unwrap()
.clone();
gpu_block.deref_mut().refcount += 1;
gpu_block
};
new_block_table.push(gpu_block);
// Free the source block on the CPU.
self.cpu_allocator.free_block(cpu_block.clone());
}
self.block_tables.insert(seq_id, new_block_table);
new_mapping
.iter()
.map(|(k, v)| (*k, v.deref_mut().block_id))
.collect::<HashMap<_, _>>()
}
}