Refactor hierarchical caching and fix TP issue (#4082)

This commit is contained in:
Zhiqiang Xie
2025-03-12 11:22:35 -07:00
committed by GitHub
parent 01090e8ac3
commit 10b544ae9b
6 changed files with 194 additions and 56 deletions

View File

@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
 class SchedulePolicy:
     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+    def __init__(
+        self,
+        policy: str,
+        tree_cache: BasePrefixCache,
+        enable_hierarchical_cache: bool = False,
+    ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
+        self.enable_hierarchical_cache = enable_hierarchical_cache

         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
@@ -149,9 +155,14 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()

             # NOTE: the prefix_indices must always be aligned with last_node
-            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                rid=r.rid, key=prefix_ids
-            )
+            if self.enable_hierarchical_cache:
+                r.prefix_indices, r.last_node, r.last_node_global = (
+                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
+                )
+            else:
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=prefix_ids
+                )

             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
@@ -428,7 +439,9 @@ class PrefillAdder:
         return self.budget_state()

-    def add_one_req(self, req: Req, has_chunked_req: bool):
+    def add_one_req(
+        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
+    ):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
             return self.add_one_req_ignore_eos(req, has_chunked_req)
@@ -448,6 +461,18 @@ class PrefillAdder:
         if total_tokens > self.rem_total_tokens:
             return AddReqResult.NO_TOKEN

+        if (
+            enable_hierarchical_cache
+            and req.last_node_global is not None
+            and req.last_node_global.evicted
+        ):
+            req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+                req.last_node_global, req.prefix_indices
+            )
+            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)

         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
             self.can_run_list.append(req)