init
This commit is contained in:
127
vllm/core/evictor_v2.py
Normal file
127
vllm/core/evictor_v2.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import OrderedDict, Tuple
|
||||
|
||||
|
||||
class EvictionPolicy(enum.Enum):
|
||||
"""Enum for eviction policy used by make_evictor to instantiate the correct
|
||||
Evictor subclass.
|
||||
"""
|
||||
LRU = enum.auto()
|
||||
|
||||
|
||||
class Evictor(ABC):
|
||||
"""The Evictor subclasses should be used by the BlockAllocator class to
|
||||
handle eviction of freed PhysicalTokenBlocks.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __contains__(self, block_id: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self) -> Tuple[int, int]:
|
||||
"""Runs the eviction algorithm and returns the evicted block's
|
||||
content hash along with physical block id along with physical block id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
|
||||
last_accessed: float):
|
||||
"""Adds block to the evictor, making it a candidate for eviction"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, block_id: int, last_accessed: float):
|
||||
"""Update corresponding block's access time in metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove(self, block_id: int):
|
||||
"""Remove a given block id from the cache."""
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
def num_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
class BlockMetaData():
|
||||
"""Data structure for storing key data describe cached block, so that
|
||||
evitor could use to make its decision which one to choose for eviction
|
||||
|
||||
Here we use physical block id as the dict key, as there maybe several
|
||||
blocks with the same content hash, but their physical id is unique.
|
||||
"""
|
||||
|
||||
def __init__(self, content_hash: int, num_hashed_tokens: int,
|
||||
last_accessed: float):
|
||||
self.content_hash = content_hash
|
||||
self.num_hashed_tokens = num_hashed_tokens
|
||||
self.last_accessed = last_accessed
|
||||
|
||||
|
||||
class LRUEvictor(Evictor):
|
||||
"""Evicts in a least-recently-used order using the last_accessed timestamp
|
||||
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
|
||||
the same last_accessed time, then the one with the largest num_hashed_tokens
|
||||
will be evicted. If two blocks each have the lowest last_accessed time and
|
||||
highest num_hashed_tokens value, then one will be chose arbitrarily
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
|
||||
|
||||
def __contains__(self, block_id: int) -> bool:
|
||||
return block_id in self.free_table
|
||||
|
||||
def evict(self) -> Tuple[int, int]:
|
||||
if len(self.free_table) == 0:
|
||||
raise ValueError("No usable cache memory left")
|
||||
|
||||
evicted_block = next(iter(self.free_table.values()))
|
||||
evicted_block_id = next(iter(self.free_table.keys()))
|
||||
# The blocks with the lowest timestamps should be placed consecutively
|
||||
# at the start of OrderedDict. Loop through all these blocks to
|
||||
# find the one with maximum number of hashed tokens.
|
||||
for _id, block in self.free_table.items():
|
||||
if evicted_block.last_accessed > block.last_accessed or (
|
||||
evicted_block.last_accessed == block.last_accessed and
|
||||
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
|
||||
evicted_block = block
|
||||
evicted_block_id = _id
|
||||
|
||||
self.free_table.pop(evicted_block_id)
|
||||
|
||||
return evicted_block_id, evicted_block.content_hash
|
||||
|
||||
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
|
||||
last_accessed: float):
|
||||
self.free_table[block_id] = BlockMetaData(content_hash,
|
||||
num_hashed_tokens,
|
||||
last_accessed)
|
||||
|
||||
def update(self, block_id: int, last_accessed: float):
|
||||
self.free_table[block_id].last_accessed = last_accessed
|
||||
|
||||
def remove(self, block_id: int):
|
||||
if block_id not in self.free_table:
|
||||
raise ValueError(
|
||||
"Attempting to remove block that's not in the evictor")
|
||||
self.free_table.pop(block_id)
|
||||
|
||||
@property
|
||||
def num_blocks(self) -> int:
|
||||
return len(self.free_table)
|
||||
|
||||
|
||||
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
|
||||
if eviction_policy == EvictionPolicy.LRU:
|
||||
return LRUEvictor()
|
||||
else:
|
||||
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
|
||||
Reference in New Issue
Block a user