Hierarchical Caching Refactoring and Fixing TP issue (#4082)

This commit is contained in:
Zhiqiang Xie
2025-03-12 11:22:35 -07:00
committed by GitHub
parent 01090e8ac3
commit 10b544ae9b
6 changed files with 194 additions and 56 deletions

View File

@@ -315,6 +315,7 @@ class Req:
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None
self.last_node_global = None
# Whether or not it is chunked. It increments whenever
# it is chunked, and decrements whenever the chunked request is
@@ -389,13 +390,24 @@ class Req:
# Whether request reached finished condition
return self.finished_reason is not None
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
def init_next_round_input(
self,
tree_cache: Optional[BasePrefixCache] = None,
enable_hierarchical_cache=False,
):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
# tree cache is None if the prefix is not computed with tree cache.
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
if enable_hierarchical_cache:
self.prefix_indices, self.last_node, self.last_node_global = (
tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(), include_evicted=True
)
)
else:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):