diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index ff4564613..5f78ee111 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -468,9 +468,9 @@ class HiRadixCache(RadixCache): # todo: more policies for prefetch progress such as timeout # the current policy is to prefetch with best effort and terminate when queuing is over - last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop( + last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[ req_id - ) + ] if operation.host_indices is None: # prefetch has not been issued due to insufficient host memory @@ -512,6 +512,7 @@ class HiRadixCache(RadixCache): host_indices[min_completed_tokens:completed_tokens] ) last_host_node.release_host() + del self.ongoing_prefetch[req_id] self.cache_controller.prefetch_tokens_occupied -= len(token_ids) return True @@ -775,9 +776,7 @@ class HiRadixCache(RadixCache): if rid not in self.ongoing_prefetch: return - last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop( - rid - ) + last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid] if operation.host_indices is None: return @@ -785,5 +784,6 @@ class HiRadixCache(RadixCache): if self.tp_world_size > 1: torch.distributed.barrier(group=self.tp_group) last_host_node.release_host() + del self.ongoing_prefetch[rid] self.cache_controller.append_host_mem_release(host_indices[:completed_tokens]) self.cache_controller.prefetch_tokens_occupied -= len(token_ids)