Sanity check to prevent performance regression (#3171)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -149,6 +149,7 @@ class Scheduler:
|
|||||||
if not self.spec_algorithm.is_none()
|
if not self.spec_algorithm.is_none()
|
||||||
else 1
|
else 1
|
||||||
)
|
)
|
||||||
|
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||||
|
|
||||||
# Distributed rank info
|
# Distributed rank info
|
||||||
self.dp_size = server_args.dp_size
|
self.dp_size = server_args.dp_size
|
||||||
@@ -831,10 +832,16 @@ class Scheduler:
|
|||||||
available_size = (
|
available_size = (
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
if available_size != self.max_total_num_tokens:
|
protected_size = self.tree_cache.protected_size()
|
||||||
|
memory_leak = available_size != (
|
||||||
|
self.max_total_num_tokens
|
||||||
|
if not self.enable_hierarchical_cache
|
||||||
|
else self.max_total_num_tokens - protected_size
|
||||||
|
)
|
||||||
|
if memory_leak:
|
||||||
msg = (
|
msg = (
|
||||||
"KV cache pool leak detected!"
|
"KV cache pool leak detected!"
|
||||||
f"{available_size=}, {self.max_total_num_tokens=}\n"
|
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||||
)
|
)
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
if crash_on_warnings():
|
if crash_on_warnings():
|
||||||
@@ -949,7 +956,14 @@ class Scheduler:
|
|||||||
res = adder.add_one_req(req)
|
res = adder.add_one_req(req)
|
||||||
if res != AddReqResult.CONTINUE:
|
if res != AddReqResult.CONTINUE:
|
||||||
if res == AddReqResult.NO_TOKEN:
|
if res == AddReqResult.NO_TOKEN:
|
||||||
self.batch_is_full = True
|
if self.enable_hierarchical_cache:
|
||||||
|
# Set batch_is_full after making sure there are requests that can be served
|
||||||
|
self.batch_is_full = len(adder.can_run_list) > 0 or (
|
||||||
|
self.running_batch is not None
|
||||||
|
and not self.running_batch.is_empty()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.batch_is_full = True
|
||||||
break
|
break
|
||||||
if self.server_args.prefill_only_one_req:
|
if self.server_args.prefill_only_one_req:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
|
|||||||
def evictable_size(self):
|
def evictable_size(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def protected_size(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def total_size(self):
|
def total_size(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
|
|||||||
|
|
||||||
def evictable_size(self):
|
def evictable_size(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def protected_size(self):
|
||||||
|
return 0
|
||||||
|
|||||||
@@ -34,7 +34,10 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class TreeNode:
|
class TreeNode:
|
||||||
def __init__(self):
|
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
def __init__(self, id: Optional[int] = None):
|
||||||
self.children = defaultdict(TreeNode)
|
self.children = defaultdict(TreeNode)
|
||||||
self.parent = None
|
self.parent = None
|
||||||
self.key = None
|
self.key = None
|
||||||
@@ -42,6 +45,23 @@ class TreeNode:
|
|||||||
self.lock_ref = 0
|
self.lock_ref = 0
|
||||||
self.last_access_time = time.time()
|
self.last_access_time = time.time()
|
||||||
|
|
||||||
|
self.hit_count = 0
|
||||||
|
# indicating the node is loading KV cache from host
|
||||||
|
self.loading = False
|
||||||
|
# store the host indices of KV cache
|
||||||
|
self.host_value = None
|
||||||
|
|
||||||
|
self.id = TreeNode.counter if id is None else id
|
||||||
|
TreeNode.counter += 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def evicted(self):
|
||||||
|
return self.value is None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def backuped(self):
|
||||||
|
return self.host_value is not None
|
||||||
|
|
||||||
def __lt__(self, other: "TreeNode"):
|
def __lt__(self, other: "TreeNode"):
|
||||||
return self.last_access_time < other.last_access_time
|
return self.last_access_time < other.last_access_time
|
||||||
|
|
||||||
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.root_node.value = []
|
self.root_node.value = []
|
||||||
self.root_node.lock_ref = 1
|
self.root_node.lock_ref = 1
|
||||||
self.evictable_size_ = 0
|
self.evictable_size_ = 0
|
||||||
|
self.protected_size_ = 0
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||||
"""Find the matching prefix from the radix tree.
|
"""Find the matching prefix from the radix tree.
|
||||||
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
while node != self.root_node:
|
while node != self.root_node:
|
||||||
if node.lock_ref == 0:
|
if node.lock_ref == 0:
|
||||||
self.evictable_size_ -= len(node.value)
|
self.evictable_size_ -= len(node.value)
|
||||||
|
self.protected_size_ += len(node.value)
|
||||||
delta -= len(node.value)
|
delta -= len(node.value)
|
||||||
node.lock_ref += 1
|
node.lock_ref += 1
|
||||||
node = node.parent
|
node = node.parent
|
||||||
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
while node != self.root_node:
|
while node != self.root_node:
|
||||||
if node.lock_ref == 1:
|
if node.lock_ref == 1:
|
||||||
self.evictable_size_ += len(node.value)
|
self.evictable_size_ += len(node.value)
|
||||||
|
self.protected_size_ -= len(node.value)
|
||||||
delta += len(node.value)
|
delta += len(node.value)
|
||||||
node.lock_ref -= 1
|
node.lock_ref -= 1
|
||||||
node = node.parent
|
node = node.parent
|
||||||
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
|
|||||||
def evictable_size(self):
|
def evictable_size(self):
|
||||||
return self.evictable_size_
|
return self.evictable_size_
|
||||||
|
|
||||||
|
def protected_size(self):
|
||||||
|
# protected size refers to the size of the cache that is locked
|
||||||
|
return self.protected_size_
|
||||||
|
|
||||||
##### Internal Helper Functions #####
|
##### Internal Helper Functions #####
|
||||||
|
|
||||||
def _match_prefix_helper(
|
def _match_prefix_helper(
|
||||||
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
|
|||||||
self.evictable_size_ -= len(node.key)
|
self.evictable_size_ -= len(node.key)
|
||||||
|
|
||||||
def _total_size_helper(self, node: TreeNode):
|
def _total_size_helper(self, node: TreeNode):
|
||||||
|
if node.evicted:
|
||||||
|
return 0
|
||||||
x = len(node.value)
|
x = len(node.value)
|
||||||
for child in node.children.values():
|
for child in node.children.values():
|
||||||
x += self._total_size_helper(child)
|
x += self._total_size_helper(child)
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ class ServerArgs:
|
|||||||
# Custom logit processor
|
# Custom logit processor
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
tool_call_parser: str = None
|
tool_call_parser: str = None
|
||||||
|
enable_hierarchical_cache: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
@@ -892,6 +893,11 @@ class ServerArgs:
|
|||||||
default=ServerArgs.tool_call_parser,
|
default=ServerArgs.tool_call_parser,
|
||||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
|
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-hierarchical-cache",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable hierarchical cache",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
Reference in New Issue
Block a user