[BUG FIX] Add failure check when a get fails, in case the wait_complete policy blocks (#9971)
Co-authored-by: mashisong <mashisong@bytedance.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -207,26 +207,25 @@ class PrefetchOperation(StorageOperation):
|
|||||||
):
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
|
||||||
self._done_flag = False
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
self._terminated_flag = False
|
||||||
self.start_time = time.monotonic()
|
self.start_time = time.monotonic()
|
||||||
|
|
||||||
super().__init__(host_indices, token_ids, last_hash)
|
super().__init__(host_indices, token_ids, last_hash)
|
||||||
|
|
||||||
def increment(self, num_tokens: int):
|
def increment(self, num_tokens: int):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._done_flag:
|
if self._terminated_flag:
|
||||||
return False
|
return False
|
||||||
self.completed_tokens += num_tokens
|
self.completed_tokens += num_tokens
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def mark_done(self):
|
def mark_terminate(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._done_flag = True
|
self._terminated_flag = True
|
||||||
|
|
||||||
def is_done(self) -> bool:
|
def is_terminated(self) -> bool:
|
||||||
return self._done_flag
|
return self._terminated_flag
|
||||||
|
|
||||||
|
|
||||||
class HiCacheController:
|
class HiCacheController:
|
||||||
@@ -628,7 +627,7 @@ class HiCacheController:
|
|||||||
return operation
|
return operation
|
||||||
|
|
||||||
def terminate_prefetch(self, operation):
|
def terminate_prefetch(self, operation):
|
||||||
operation.mark_done()
|
operation.mark_terminate()
|
||||||
return operation.completed_tokens, operation.hash_value
|
return operation.completed_tokens, operation.hash_value
|
||||||
|
|
||||||
def append_host_mem_release(self, host_indices: torch.Tensor):
|
def append_host_mem_release(self, host_indices: torch.Tensor):
|
||||||
@@ -709,6 +708,7 @@ class HiCacheController:
|
|||||||
operation.completed_tokens
|
operation.completed_tokens
|
||||||
!= prev_completed_tokens + len(batch_hashes) * self.page_size
|
!= prev_completed_tokens + len(batch_hashes) * self.page_size
|
||||||
):
|
):
|
||||||
|
operation.mark_terminate()
|
||||||
break # Some operations fail or operation terminated by controller
|
break # Some operations fail or operation terminated by controller
|
||||||
# release pre-allocated memory
|
# release pre-allocated memory
|
||||||
self.append_host_mem_release(
|
self.append_host_mem_release(
|
||||||
|
|||||||
@@ -482,15 +482,22 @@ class HiRadixCache(RadixCache):
|
|||||||
# unknown prefetch stop policy, just return True
|
# unknown prefetch stop policy, just return True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
operation_terminated = operation.is_terminated()
|
||||||
if self.tp_world_size > 1:
|
if self.tp_world_size > 1:
|
||||||
can_terminate = torch.tensor(can_terminate, dtype=torch.int)
|
states = torch.tensor(
|
||||||
|
[1 - int(can_terminate), int(operation_terminated)],
|
||||||
|
dtype=torch.int,
|
||||||
|
)
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
can_terminate,
|
states,
|
||||||
op=torch.distributed.ReduceOp.MIN,
|
op=torch.distributed.ReduceOp.MAX,
|
||||||
group=self.tp_group,
|
group=self.tp_group,
|
||||||
)
|
)
|
||||||
can_terminate = bool(can_terminate.item())
|
can_terminate = states[0].item() == 0
|
||||||
|
operation_terminated = states[1].item() == 1
|
||||||
|
# the operation should be terminated if it is already terminated on any TP worker
|
||||||
|
# or it meets the termination condition on all TP workers
|
||||||
|
can_terminate = can_terminate or operation_terminated
|
||||||
return can_terminate
|
return can_terminate
|
||||||
|
|
||||||
def check_prefetch_progress(self, req_id: str) -> bool:
|
def check_prefetch_progress(self, req_id: str) -> bool:
|
||||||
@@ -517,7 +524,7 @@ class HiRadixCache(RadixCache):
|
|||||||
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
||||||
|
|
||||||
min_completed_tokens = completed_tokens
|
min_completed_tokens = completed_tokens
|
||||||
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
|
if self.tp_world_size > 1:
|
||||||
# synchronize TP workers to make the same update to hiradix cache
|
# synchronize TP workers to make the same update to hiradix cache
|
||||||
completed_tokens_tensor = torch.tensor(
|
completed_tokens_tensor = torch.tensor(
|
||||||
min_completed_tokens, dtype=torch.int
|
min_completed_tokens, dtype=torch.int
|
||||||
|
|||||||
Reference in New Issue
Block a user