diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4bf76f78b..af24f941c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2403,6 +2403,9 @@ class Scheduler(
                 # This only works for requests that have not started anything.
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
+                if self.enable_hicache_storage:
+                    # to release prefetch events associated with the request
+                    self.tree_cache.release_aborted_request(req.rid)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 2bd231ae6..ff4564613 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
             req_id
-        ]
+        )
 
         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
 
         return True
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
             if not cur_child.evicted:
                 stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+            rid
+        )
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)