[HiCacheStorage] fix abort request host memory leaks (#9874)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
huangtingwei
2025-09-02 09:59:29 +08:00
committed by GitHub
parent 9db8025376
commit cb9e0e4180
2 changed files with 22 additions and 3 deletions

View File

@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
# todo: more policies for prefetch progress such as timeout
# the current policy is to prefetch with best effort and terminate when queuing is over
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
req_id
]
)
if operation.host_indices is None:
# prefetch has not been issued due to insufficient host memory
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
host_indices[min_completed_tokens:completed_tokens]
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
return True
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
if not cur_child.evicted:
stack.append(cur_child)
return ret_list
def release_aborted_request(self, rid: str) -> None:
    """Drop the prefetch bookkeeping held for an aborted request.

    Removes ``rid`` from ``self.ongoing_prefetch`` and, when the prefetch
    had actually been issued, terminates it, releases the host node, hands
    the host indices of the already-completed tokens to the cache
    controller for release, and subtracts the request's tokens from
    ``prefetch_tokens_occupied``.

    Args:
        rid: Id of the aborted request. Ids with no ongoing prefetch are
            ignored.
    """
    # One lookup-and-remove instead of a membership test followed by pop().
    entry = self.ongoing_prefetch.pop(rid, None)
    if entry is None:
        return
    last_host_node, token_ids, host_indices, operation = entry

    if operation.host_indices is None:
        # The operation holds no host indices, so there is nothing to
        # terminate or release here; only the bookkeeping entry is dropped.
        return

    completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
    if self.tp_world_size > 1:
        # Wait on the TP group before releasing anything.
        # NOTE(review): presumably keeps ranks in step after termination
        # -- confirm against the cache controller.
        torch.distributed.barrier(group=self.tp_group)

    last_host_node.release_host()
    # Nothing fetched for an aborted request is kept, so every completed
    # token's host index is queued for release.
    self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
    self.cache_controller.prefetch_tokens_occupied -= len(token_ids)