Hierarchical Caching Refactoring and Fixing TP issue (#4082)
This commit is contained in:
@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
|
||||
class SchedulePolicy:
|
||||
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
|
||||
|
||||
def __init__(self, policy: str, tree_cache: BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
policy: str,
|
||||
tree_cache: BasePrefixCache,
|
||||
enable_hierarchical_cache: bool = False,
|
||||
):
|
||||
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
|
||||
self.tree_cache = tree_cache
|
||||
self.enable_hierarchical_cache = enable_hierarchical_cache
|
||||
|
||||
# It is used to find the matching prefix for in-batch prefix caching.
|
||||
self.waiting_queue_radix_tree = RadixCache(
|
||||
@@ -149,9 +155,14 @@ class SchedulePolicy:
|
||||
prefix_ids = r.adjust_max_prefix_ids()
|
||||
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
)
|
||||
if self.enable_hierarchical_cache:
|
||||
r.prefix_indices, r.last_node, r.last_node_global = (
|
||||
self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
|
||||
)
|
||||
else:
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=prefix_ids
|
||||
)
|
||||
|
||||
# NOTE(sang): This logic is for in-batch prefix caching;
|
||||
# If more than one request has a small matching prefix from
|
||||
@@ -428,7 +439,9 @@ class PrefillAdder:
|
||||
|
||||
return self.budget_state()
|
||||
|
||||
def add_one_req(self, req: Req, has_chunked_req: bool):
|
||||
def add_one_req(
|
||||
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
|
||||
):
|
||||
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||
|
||||
@@ -448,6 +461,18 @@ class PrefillAdder:
|
||||
if total_tokens > self.rem_total_tokens:
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if (
|
||||
enable_hierarchical_cache
|
||||
and req.last_node_global is not None
|
||||
and req.last_node_global.evicted
|
||||
):
|
||||
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
||||
req.last_node_global, req.prefix_indices
|
||||
)
|
||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||
input_tokens = req.extend_input_len
|
||||
prefix_len = len(req.prefix_indices)
|
||||
|
||||
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
|
||||
Reference in New Issue
Block a user