From 33467c05a4290e460b1a20b6b698a033752ea7c3 Mon Sep 17 00:00:00 2001 From: Shisong Ma <31835442+mss1213@users.noreply.github.com> Date: Mon, 8 Sep 2025 09:34:04 +0800 Subject: [PATCH] [BUG FIX] add fail check when get fail in case wait complete block (#9971) Co-authored-by: mashisong Co-authored-by: Zhiqiang Xie --- .../sglang/srt/managers/cache_controller.py | 16 ++++++++-------- python/sglang/srt/mem_cache/hiradix_cache.py | 19 +++++++++++++------ 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 6bc7bd8f1..6846022f9 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -207,26 +207,25 @@ class PrefetchOperation(StorageOperation): ): self.request_id = request_id - self._done_flag = False self._lock = threading.Lock() - + self._terminated_flag = False self.start_time = time.monotonic() super().__init__(host_indices, token_ids, last_hash) def increment(self, num_tokens: int): with self._lock: - if self._done_flag: + if self._terminated_flag: return False self.completed_tokens += num_tokens return True - def mark_done(self): + def mark_terminate(self): with self._lock: - self._done_flag = True + self._terminated_flag = True - def is_done(self) -> bool: - return self._done_flag + def is_terminated(self) -> bool: + return self._terminated_flag class HiCacheController: @@ -628,7 +627,7 @@ class HiCacheController: return operation def terminate_prefetch(self, operation): - operation.mark_done() + operation.mark_terminate() return operation.completed_tokens, operation.hash_value def append_host_mem_release(self, host_indices: torch.Tensor): @@ -709,6 +708,7 @@ class HiCacheController: operation.completed_tokens != prev_completed_tokens + len(batch_hashes) * self.page_size ): + operation.mark_terminate() break # Some operations fail or operation terminated by controller # release pre-allocated memory self.append_host_mem_release( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index a861e233e..5883c1f15 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -482,15 +482,22 @@ class HiRadixCache(RadixCache): # unknown prefetch stop policy, just return True return True + operation_terminated = operation.is_terminated() if self.tp_world_size > 1: - can_terminate = torch.tensor(can_terminate, dtype=torch.int) + states = torch.tensor( + [1 - int(can_terminate), int(operation_terminated)], + dtype=torch.int, + ) torch.distributed.all_reduce( - can_terminate, - op=torch.distributed.ReduceOp.MIN, + states, + op=torch.distributed.ReduceOp.MAX, group=self.tp_group, ) - can_terminate = bool(can_terminate.item()) - + can_terminate = states[0].item() == 0 + operation_terminated = states[1].item() == 1 + # the operation should be terminated if it is already terminated on any TP worker + # or it meets the termination condition on all TP workers + can_terminate = can_terminate or operation_terminated return can_terminate def check_prefetch_progress(self, req_id: str) -> bool: @@ -517,7 +524,7 @@ class HiRadixCache(RadixCache): logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens") min_completed_tokens = completed_tokens - if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete": + if self.tp_world_size > 1: # synchrnoize TP workers to make the same update to hiradix cache completed_tokens_tensor = torch.tensor( min_completed_tokens, dtype=torch.int