Simple prefetch policy (#8692)

This commit is contained in:
pansicheng
2025-08-08 17:09:28 +08:00
committed by GitHub
parent 7490e3f67d
commit e2fd2b9c7e
6 changed files with 148 additions and 36 deletions

View File

@@ -2,11 +2,12 @@ import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
hicache_io_backend: str,
hicache_mem_layout: str,
hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
):
if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
prefetch_threshold=self.prefetch_threshold,
)
self.prefetch_stop_policy = hicache_storage_prefetch_policy
# todo: customizable storage prefetch timeout
self.prefetch_timeout = 3 # seconds
logger.info(
f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
)
# record the nodes with ongoing write through
self.ongoing_write_through = {}
# record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
for _ in range(queue_size.item()):
req_id = self.cache_controller.prefetch_revoke_queue.get()
if req_id in self.ongoing_prefetch:
last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
else:
# the revoked operation already got terminated
pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
host_node.release_host()
del self.ongoing_backup[ack_id]
def check_prefetch_progress(self, req_id: str):
def can_terminate_prefetch(self, operation: PrefetchOperation):
    """Decide whether an ongoing prefetch operation may be terminated now.

    The decision follows ``self.prefetch_stop_policy``:
      - ``"best_effort"``: always terminable, regardless of progress.
      - ``"wait_complete"``: terminable only once every expected token has
        been fetched.
      - ``"timeout"``: terminable once fetching is complete, or once the
        operation has been running longer than ``self.prefetch_timeout``
        seconds.
    Any unrecognized policy value is treated as terminable.

    Under tensor parallelism (``self.tp_world_size > 1``) the local decision
    is combined across ranks with a MIN all-reduce so that every rank makes
    the same call.
    """
    policy = self.prefetch_stop_policy
    if policy == "best_effort":
        return True

    # The prefetch is complete once every hashed page worth of tokens
    # has arrived.
    expected_tokens = len(operation.hash_value) * self.page_size
    is_complete = operation.completed_tokens == expected_tokens

    if policy == "wait_complete":
        decision = is_complete
    elif policy == "timeout":
        elapsed = time.monotonic() - operation.start_time
        decision = is_complete or elapsed > self.prefetch_timeout
    else:
        # Unknown prefetch stop policy: fall back to terminating
        # unconditionally.
        return True

    if self.tp_world_size > 1:
        # Take the MIN across TP ranks so the operation terminates only
        # when every rank is ready to terminate.
        flag = torch.tensor(decision, dtype=torch.int)
        torch.distributed.all_reduce(
            flag,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        decision = bool(flag.item())
    return decision
def check_prefetch_progress(self, req_id: str) -> bool:
if req_id not in self.ongoing_prefetch:
# there is no ongoing prefetch for this request or it has been revoked
return
return True
# todo: more policies for prefetch progress such as timeout
# the current policy is to prefetch with best effort and terminate when queuing is over
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
req_id
]
if not self.can_terminate_prefetch(operation):
return False
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
operation
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
min_completed_tokens = completed_tokens
if self.tp_world_size > 1:
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
# synchronize TP workers to make the same update to hiradix cache
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
return True
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
host_indices,
operation,
)
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
node.last_access_time = time.monotonic()