Reduce overhead when fork(1) (#375)
This commit is contained in:
@@ -256,9 +256,15 @@ class StreamExecutor:
|
|||||||
ret = self.meta_info.get(name, None)
|
ret = self.meta_info.get(name, None)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
|
def fork(
|
||||||
self.submit(SglCommitLazy())
|
self,
|
||||||
self.sync()
|
number: int,
|
||||||
|
position_ids_offset: Optional[List[int]] = None,
|
||||||
|
copy: bool = False,
|
||||||
|
):
|
||||||
|
if number > 1 or copy:
|
||||||
|
self.submit(SglCommitLazy())
|
||||||
|
self.sync()
|
||||||
|
|
||||||
number = int(number)
|
number = int(number)
|
||||||
|
|
||||||
@@ -641,15 +647,20 @@ class ProgramState:
|
|||||||
yield
|
yield
|
||||||
self.stream_executor.submit(SglVarScopeEnd(name))
|
self.stream_executor.submit(SglVarScopeEnd(name))
|
||||||
|
|
||||||
def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
|
def fork(
|
||||||
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
self,
|
||||||
|
number: int = 1,
|
||||||
|
position_ids_offset: Optional[List[int]] = None,
|
||||||
|
copy: bool = False,
|
||||||
|
):
|
||||||
|
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
|
||||||
states = [ProgramState(x) for x in stream_executors]
|
states = [ProgramState(x) for x in stream_executors]
|
||||||
state_group = ProgramStateGroup(states, self)
|
state_group = ProgramStateGroup(states, self)
|
||||||
return state_group
|
return state_group
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
||||||
state_group = self.fork(1, position_ids_offset)
|
state_group = self.fork(1, position_ids_offset, True)
|
||||||
try:
|
try:
|
||||||
yield state_group[0]
|
yield state_group[0]
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -179,7 +179,9 @@ class RadixCache:
|
|||||||
|
|
||||||
def _print_helper(self, node, indent):
|
def _print_helper(self, node, indent):
|
||||||
for _, child in node.children.items():
|
for _, child in node.children.items():
|
||||||
print(" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}")
|
print(
|
||||||
|
" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
|
||||||
|
)
|
||||||
self._print_helper(child, indent=indent + 2)
|
self._print_helper(child, indent=indent + 2)
|
||||||
|
|
||||||
def _delete_leaf(self, node):
|
def _delete_leaf(self, node):
|
||||||
|
|||||||
Reference in New Issue
Block a user