Sanity check to prevent performance regression (#3171)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -149,6 +149,7 @@ class Scheduler:
|
||||
if not self.spec_algorithm.is_none()
|
||||
else 1
|
||||
)
|
||||
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||
|
||||
# Distributed rank info
|
||||
self.dp_size = server_args.dp_size
|
||||
@@ -831,10 +832,16 @@ class Scheduler:
|
||||
available_size = (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
)
|
||||
if available_size != self.max_total_num_tokens:
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
self.max_total_num_tokens
|
||||
if not self.enable_hierarchical_cache
|
||||
else self.max_total_num_tokens - protected_size
|
||||
)
|
||||
if memory_leak:
|
||||
msg = (
|
||||
"KV cache pool leak detected!"
|
||||
f"{available_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
)
|
||||
warnings.warn(msg)
|
||||
if crash_on_warnings():
|
||||
@@ -949,7 +956,14 @@ class Scheduler:
|
||||
res = adder.add_one_req(req)
|
||||
if res != AddReqResult.CONTINUE:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
self.batch_is_full = True
|
||||
if self.enable_hierarchical_cache:
|
||||
# Set batch_is_full after making sure there are requests that can be served
|
||||
self.batch_is_full = len(adder.can_run_list) > 0 or (
|
||||
self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
)
|
||||
else:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
if self.server_args.prefill_only_one_req:
|
||||
break
|
||||
|
||||
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
|
||||
def evictable_size(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def protected_size(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
|
||||
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
|
||||
def protected_size(self):
|
||||
return 0
|
||||
|
||||
@@ -34,7 +34,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class TreeNode:
|
||||
def __init__(self):
|
||||
|
||||
counter = 0
|
||||
|
||||
def __init__(self, id: Optional[int] = None):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent = None
|
||||
self.key = None
|
||||
@@ -42,6 +45,23 @@ class TreeNode:
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.time()
|
||||
|
||||
self.hit_count = 0
|
||||
# indicating the node is loading KV cache from host
|
||||
self.loading = False
|
||||
# store the host indices of KV cache
|
||||
self.host_value = None
|
||||
|
||||
self.id = TreeNode.counter if id is None else id
|
||||
TreeNode.counter += 1
|
||||
|
||||
@property
|
||||
def evicted(self):
|
||||
return self.value is None
|
||||
|
||||
@property
|
||||
def backuped(self):
|
||||
return self.host_value is not None
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.root_node.value = []
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
self.protected_size_ = 0
|
||||
|
||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 0:
|
||||
self.evictable_size_ -= len(node.value)
|
||||
self.protected_size_ += len(node.value)
|
||||
delta -= len(node.value)
|
||||
node.lock_ref += 1
|
||||
node = node.parent
|
||||
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
|
||||
while node != self.root_node:
|
||||
if node.lock_ref == 1:
|
||||
self.evictable_size_ += len(node.value)
|
||||
self.protected_size_ -= len(node.value)
|
||||
delta += len(node.value)
|
||||
node.lock_ref -= 1
|
||||
node = node.parent
|
||||
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
|
||||
def protected_size(self):
|
||||
# protected size refers to the size of the cache that is locked
|
||||
return self.protected_size_
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(
|
||||
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
|
||||
self.evictable_size_ -= len(node.key)
|
||||
|
||||
def _total_size_helper(self, node: TreeNode):
|
||||
if node.evicted:
|
||||
return 0
|
||||
x = len(node.value)
|
||||
for child in node.children.values():
|
||||
x += self._total_size_helper(child)
|
||||
|
||||
@@ -163,6 +163,7 @@ class ServerArgs:
|
||||
# Custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
tool_call_parser: str = None
|
||||
enable_hierarchical_cache: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
@@ -892,6 +893,11 @@ class ServerArgs:
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-hierarchical-cache",
|
||||
action="store_true",
|
||||
help="Enable hierarchical cache",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user