Hierarchical Caching Refactoring and Fixing TP issue (#4082)
This commit is contained in:
@@ -315,6 +315,7 @@ class Req:
|
||||
# The relative logprob_start_len in an extend batch
|
||||
self.extend_logprob_start_len = 0
|
||||
self.last_node = None
|
||||
self.last_node_global = None
|
||||
|
||||
# Whether or not it is chunked. It increments whenever
|
||||
# it is chunked, and decrements whenever a chunked request is
|
||||
@@ -389,13 +390,24 @@ class Req:
|
||||
# Whether request reached finished condition
|
||||
return self.finished_reason is not None
|
||||
|
||||
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
|
||||
def init_next_round_input(
|
||||
self,
|
||||
tree_cache: Optional[BasePrefixCache] = None,
|
||||
enable_hierarchical_cache=False,
|
||||
):
|
||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||
if tree_cache is not None:
|
||||
# tree cache is None if the prefix is not computed with tree cache.
|
||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||
)
|
||||
if enable_hierarchical_cache:
|
||||
self.prefix_indices, self.last_node, self.last_node_global = (
|
||||
tree_cache.match_prefix(
|
||||
key=self.adjust_max_prefix_ids(), include_evicted=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||
)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
|
||||
Reference in New Issue
Block a user