[HiCacheStorage] fix abort request host memory leaks (#9874)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -2403,6 +2403,9 @@ class Scheduler(
|
|||||||
# This only works for requests that have not started anything.
|
# This only works for requests that have not started anything.
|
||||||
# We still need to send something back to TokenizerManager to clean up the state.
|
# We still need to send something back to TokenizerManager to clean up the state.
|
||||||
req = self.waiting_queue.pop(i)
|
req = self.waiting_queue.pop(i)
|
||||||
|
if self.enable_hicache_storage:
|
||||||
|
# to release prefetch events associated with the request
|
||||||
|
self.tree_cache.release_aborted_request(req.rid)
|
||||||
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
||||||
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
|
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
|||||||
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
|
|||||||
|
|
||||||
# todo: more policies for prefetch progress such as timeout
|
# todo: more policies for prefetch progress such as timeout
|
||||||
# the current policy is to prefetch with best effort and terminate when queuing is over
|
# the current policy is to prefetch with best effort and terminate when queuing is over
|
||||||
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
|
||||||
req_id
|
req_id
|
||||||
]
|
)
|
||||||
|
|
||||||
if operation.host_indices is None:
|
if operation.host_indices is None:
|
||||||
# prefetch has not been issued due to insufficient host memory
|
# prefetch has not been issued due to insufficient host memory
|
||||||
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
|
|||||||
host_indices[min_completed_tokens:completed_tokens]
|
host_indices[min_completed_tokens:completed_tokens]
|
||||||
)
|
)
|
||||||
last_host_node.release_host()
|
last_host_node.release_host()
|
||||||
del self.ongoing_prefetch[req_id]
|
|
||||||
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
|
|||||||
if not cur_child.evicted:
|
if not cur_child.evicted:
|
||||||
stack.append(cur_child)
|
stack.append(cur_child)
|
||||||
return ret_list
|
return ret_list
|
||||||
|
|
||||||
|
def release_aborted_request(self, rid: str):
    """Release the prefetch state tracked for an aborted request.

    Removes the entry for ``rid`` from ``self.ongoing_prefetch`` (doing
    nothing if there is none). If the prefetch operation holds host
    indices, the operation is terminated, the host node is released, the
    host indices covering the completed tokens are handed to the cache
    controller for release, and the controller's occupied-token counter is
    reduced by the number of tokens in the request's prefetch.

    Args:
        rid: Identifier of the aborted request.
    """
    entry = self.ongoing_prefetch.pop(rid, None)
    if entry is None:
        # Nothing is being prefetched for this request.
        return

    host_node, prefetch_token_ids, reserved_indices, prefetch_op = entry
    if prefetch_op.host_indices is None:
        # The operation holds no host indices, so there is nothing to
        # release (presumably the prefetch was never issued -- confirm
        # against the prefetch scheduling code).
        return

    controller = self.cache_controller
    num_completed, _ = controller.terminate_prefetch(prefetch_op)
    if self.tp_world_size > 1:
        # Barrier across the tensor-parallel group before releasing.
        torch.distributed.barrier(group=self.tp_group)

    host_node.release_host()
    controller.append_host_mem_release(reserved_indices[:num_completed])
    controller.prefetch_tokens_occupied -= len(prefetch_token_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user