Sanity check to prevent performance regression (#3171)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
|
||||
def evictable_size(self):
    """Report the evictable size of the cache.

    The base implementation has an empty body and therefore returns
    ``None``; the concrete caches shown in this change return an integer.
    """
    pass
|
||||
|
||||
@abstractmethod
def protected_size(self):
    """Report the protected size of the cache.

    Declared abstract: the base body only raises, so every concrete cache
    is expected to supply its own implementation.

    Raises:
        NotImplementedError: always, when the base body is invoked.
    """
    raise NotImplementedError()
|
||||
|
||||
def total_size(self):
    """Report the total size of the cache.

    The base implementation provides no value of its own.

    Raises:
        NotImplementedError: always, when the base body is invoked.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
|
||||
|
||||
def evictable_size(self):
    """Return the evictable size, which is always 0 for this cache."""
    return 0
|
||||
|
||||
def protected_size(self):
    """Return the protected size, which is always 0 for this cache."""
    return 0
|
||||
|
||||
@@ -34,7 +34,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class TreeNode:
|
||||
def __init__(self):
|
||||
|
||||
counter = 0
|
||||
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent = None
|
||||
self.key = None
|
||||
@@ -42,6 +45,23 @@ class TreeNode:
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.time()
|
||||
|
||||
self.hit_count = 0
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# store the host indices of KV cache
|
||||
self.host_value = None
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
|
||||
@property
def evicted(self):
    """Whether this node counts as evicted, i.e. its ``value`` is ``None``."""
    return self.value is None
|
||||
|
||||
@property
def backuped(self):
    """Whether this node has a host-side copy, i.e. ``host_value`` is set.

    ``host_value`` holds the host indices of the KV cache and is ``None``
    until something assigns it.
    """
    return self.host_value is not None
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
    """Order nodes by ``last_access_time``: the less recently accessed
    node compares as smaller.
    """
    return self.last_access_time < other.last_access_time
|
||||
|
||||
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.root_node.value = []
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
self.protected_size_ = 0
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 0:
|
||||
self.evictable_size_ -= len(node.value)
|
||||
self.protected_size_ += len(node.value)
|
||||
delta -= len(node.value)
|
||||
node.lock_ref += 1
|
||||
node = node.parent
|
||||
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 1:
|
||||
self.evictable_size_ += len(node.value)
|
||||
self.protected_size_ -= len(node.value)
|
||||
delta += len(node.value)
|
||||
node.lock_ref -= 1
|
||||
node = node.parent
|
||||
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
|
||||
def evictable_size(self):
    """Return the current value of the ``evictable_size_`` counter."""
    return self.evictable_size_
|
||||
|
||||
def protected_size(self):
    """Return the current value of the ``protected_size_`` counter."""
    # protected size refers to the size of the cache that is locked
    return self.protected_size_
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(
|
||||
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
|
||||
self.evictable_size_ -= len(node.key)
|
||||
|
||||
def _total_size_helper(self, node: TreeNode):
|
||||
if node.evicted:
|
||||
return 0
|
||||
x = len(node.value)
|
||||
for child in node.children.values():
|
||||
x += self._total_size_helper(child)
|
||||
|
||||
Reference in New Issue
Block a user