fix(cache): move ongoing_prefetch pop after validation to prevent leak (#9927)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
|
||||
|
||||
# todo: more policies for prefetch progress such as timeout
|
||||
# the current policy is to prefetch with best effort and terminate when queuing is over
|
||||
-last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
 req_id
-)
+]
|
||||
|
||||
if operation.host_indices is None:
|
||||
# prefetch has not been issued due to insufficient host memory
|
||||
@@ -512,6 +512,7 @@ class HiRadixCache(RadixCache):
|
||||
host_indices[min_completed_tokens:completed_tokens]
|
||||
)
|
||||
last_host_node.release_host()
|
||||
+del self.ongoing_prefetch[req_id]
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
|
||||
return True
|
||||
@@ -775,9 +776,7 @@ class HiRadixCache(RadixCache):
|
||||
if rid not in self.ongoing_prefetch:
|
||||
return
|
||||
|
||||
-last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
-rid
-)
+last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
|
||||
if operation.host_indices is None:
|
||||
return
|
||||
|
||||
@@ -785,5 +784,6 @@ class HiRadixCache(RadixCache):
|
||||
if self.tp_world_size > 1:
|
||||
torch.distributed.barrier(group=self.tp_group)
|
||||
last_host_node.release_host()
|
||||
+del self.ongoing_prefetch[rid]
|
||||
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
|
||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||
|
||||
Reference in New Issue
Block a user