Support overlapping two batches (#4068)
This commit is contained in:
@@ -12,7 +12,7 @@ if _ENABLE_PROFILE:
|
||||
|
||||
|
||||
def execute_operations(inputs, operations):
|
||||
stages = _convert_operations_to_stages(decorate_operations(operations))
|
||||
stages = _convert_operations_to_stages(operations)
|
||||
executor = _StageExecutor("primary", stages, inputs=inputs)
|
||||
for _ in range(executor.num_stages):
|
||||
executor.next()
|
||||
@@ -20,6 +20,37 @@ def execute_operations(inputs, operations):
|
||||
return executor.output
|
||||
|
||||
|
||||
def execute_overlapped_operations(
|
||||
inputs_arr: Sequence,
|
||||
operations_arr: Sequence,
|
||||
delta_stages: Sequence[int],
|
||||
) -> Sequence:
|
||||
# Make it explicit for clarity; if we need multi-batch overlap, this can be generalized
|
||||
inputs_a, inputs_b = inputs_arr
|
||||
operations_a, operations_b = operations_arr
|
||||
delta_stage_a, delta_stage_b = delta_stages
|
||||
assert delta_stage_a == 0
|
||||
delta_stage = delta_stage_b
|
||||
|
||||
stages_a = _convert_operations_to_stages(operations_a)
|
||||
stages_b = _convert_operations_to_stages(operations_b)
|
||||
executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
|
||||
executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
|
||||
|
||||
for _ in range(delta_stage):
|
||||
executor_a.next()
|
||||
|
||||
for _ in range(executor_a.num_stages - delta_stage):
|
||||
executor_a.next()
|
||||
executor_b.next()
|
||||
|
||||
for _ in range(delta_stage):
|
||||
executor_b.next()
|
||||
|
||||
assert executor_a.done and executor_b.done
|
||||
return [executor_a.output, executor_b.output]
|
||||
|
||||
|
||||
class YieldOperation:
|
||||
pass
|
||||
|
||||
@@ -109,6 +140,9 @@ class _StateDict:
|
||||
for k, v in values.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def get(self, item):
|
||||
return self._data.get(item)
|
||||
|
||||
def clear(self, expect_keys: Sequence[str]):
|
||||
if set(self._data.keys()) != set(expect_keys):
|
||||
raise Exception(
|
||||
@@ -119,6 +153,7 @@ class _StateDict:
|
||||
|
||||
|
||||
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
|
||||
operations = _decorate_operations(operations)
|
||||
operation_chunks = list(
|
||||
_chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
|
||||
)
|
||||
@@ -140,7 +175,7 @@ def _chunk_by_separator(
|
||||
yield pending_items
|
||||
|
||||
|
||||
def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
|
||||
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
|
||||
return [_decorate_operation(op, debug_name_prefix) for op in operations]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user