[BUG FIX] add fail check when get fail in case wait complete block (#9971)
Co-authored-by: mashisong <mashisong@bytedance.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -482,15 +482,22 @@ class HiRadixCache(RadixCache):
|
||||
# unknown prefetch stop policy, just return True
|
||||
return True
|
||||
|
||||
operation_terminated = operation.is_terminated()
|
||||
if self.tp_world_size > 1:
|
||||
can_terminate = torch.tensor(can_terminate, dtype=torch.int)
|
||||
states = torch.tensor(
|
||||
[1 - int(can_terminate), int(operation_terminated)],
|
||||
dtype=torch.int,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
can_terminate,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
states,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=self.tp_group,
|
||||
)
|
||||
can_terminate = bool(can_terminate.item())
|
||||
|
||||
can_terminate = states[0].item() == 0
|
||||
operation_terminated = states[1].item() == 1
|
||||
# the operation should be terminated if it is already terminated on any TP worker
|
||||
# or it meets the termination condition on all TP workers
|
||||
can_terminate = can_terminate or operation_terminated
|
||||
return can_terminate
|
||||
|
||||
def check_prefetch_progress(self, req_id: str) -> bool:
|
||||
@@ -517,7 +524,7 @@ class HiRadixCache(RadixCache):
|
||||
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
||||
|
||||
min_completed_tokens = completed_tokens
|
||||
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
|
||||
if self.tp_world_size > 1:
|
||||
# synchrnoize TP workers to make the same update to hiradix cache
|
||||
completed_tokens_tensor = torch.tensor(
|
||||
min_completed_tokens, dtype=torch.int
|
||||
|
||||
Reference in New Issue
Block a user